From b832668768050fe4037f7261e356b22ac9783f3d Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Wed, 15 Jul 2020 11:25:39 +0300 Subject: [PATCH] Minprotobuf refactoring 2 (#269) * Protobuf refactoring stage II. * Remove NoError. * Change trace level for invalid message. --- libp2p/crypto/crypto.nim | 70 +++-- libp2p/multiaddress.nim | 47 ++-- libp2p/peerid.nim | 21 +- libp2p/protobuf/minprotobuf.nim | 315 ++++++++++++----------- libp2p/protocols/identify.nim | 74 +++--- libp2p/protocols/pubsub/pubsubpeer.nim | 18 +- libp2p/protocols/pubsub/rpc/protobuf.nim | 157 ++++++----- libp2p/protocols/secure/noise.nim | 6 +- tests/testminprotobuf.nim | 158 ++++++++---- 9 files changed, 473 insertions(+), 393 deletions(-) diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 893a224..a6a291e 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -31,11 +31,6 @@ type ECDSA, NoSupport - CipherScheme* = enum - Aes128 = 0, - Aes256, - Blowfish - DigestSheme* = enum Sha256, Sha512 @@ -283,7 +278,9 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool = var buffer: seq[byte] if len(data) > 0: var pb = initProtoBuffer(@data) - if pb.getField(1, id) and pb.getField(2, buffer): + let r1 = pb.getField(1, id) + let r2 = pb.getField(2, buffer) + if r1.isOk() and r1.get() and r2.isOk() and r2.get(): if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0: var scheme = cast[PKScheme](cast[int8](id)) when key is PrivateKey: @@ -743,9 +740,15 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. var pb = initProtoBuffer(message) - pb.getField(1, nonce) and pb.getField(2, pubkey) and - pb.getField(3, exchanges) and pb.getField(4, ciphers) and - pb.getField(5, hashes) + let r1 = pb.getField(1, nonce) + let r2 = pb.getField(2, pubkey) + let r3 = pb.getField(3, exchanges) + let r4 = pb.getField(4, ciphers) + let r5 = pb.getField(5, hashes) + + r1.isOk() and r1.get() and r2.isOk() and r2.get() and + r3.isOk() and r3.get() and r4.isOk() and r4.get() and + r5.isOk() and r5.get() proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = ## Create SecIO exchange message using ephemeral public key ``epubkey`` and @@ -763,7 +766,9 @@ proc decodeExchange*(message: seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. var pb = initProtoBuffer(message) - pb.getField(1, pubkey) and pb.getField(2, signature) + let r1 = pb.getField(1, pubkey) + let r2 = pb.getField(2, signature) + r1.isOk() and r1.get() and r2.isOk() and r2.get() ## Serialization/Deserialization helpers @@ -825,28 +830,37 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {. value = sig proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int, - value: var T): bool = + value: var T): ProtoResult[bool] = + ## Deserialize public/private key from protobuf's message ``pb`` using field + ## index ``field``. + ## + ## On success deserialized key will be stored in ``value``. var buffer: seq[byte] var key: T - if not(getField(pb, field, buffer)): - return false - if len(buffer) == 0: - return false - if key.init(buffer): - value = key - true + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) else: - false + if key.init(buffer): + value = key + ok(true) + else: + err(ProtoError.IncorrectBlob) -proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool = +proc getField*(pb: ProtoBuffer, field: int, + value: var Signature): ProtoResult[bool] = + ## Deserialize signature from protobuf's message ``pb`` using field index + ## ``field``. + ## + ## On success deserialized signature will be stored in ``value``. var buffer: seq[byte] var sig: Signature - if not(getField(pb, field, buffer)): - return false - if len(buffer) == 0: - return false - if sig.init(buffer): - value = sig - true + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) else: - false + if sig.init(buffer): + value = sig + ok(true) + else: + err(ProtoError.IncorrectBlob) diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index 3d6d85e..f07c2a2 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -1025,31 +1025,34 @@ proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} = write(pb, field, value.data.buffer) proc getField*(pb: var ProtoBuffer, field: int, - value: var MultiAddress): bool {.inline.} = + value: var MultiAddress): ProtoResult[bool] {. + inline.} = var buffer: seq[byte] - if not(getField(pb, field, buffer)): - return false - if len(buffer) == 0: - return false - let ma = MultiAddress.init(buffer) - if ma.isOk(): - value = ma.get() - true + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) else: - false + let ma = MultiAddress.init(buffer) + if ma.isOk(): + value = ma.get() + ok(true) + else: + err(ProtoError.IncorrectBlob) proc getRepeatedField*(pb: var ProtoBuffer, field: int, - value: var seq[MultiAddress]): bool {.inline.} = + value: var seq[MultiAddress]): ProtoResult[bool] {. + inline.} = var items: seq[seq[byte]] value.setLen(0) - if not(getRepeatedField(pb, field, items)): - return false - if len(items) == 0: - return true - for item in items: - let ma = MultiAddress.init(item) - if ma.isOk(): - value.add(ma.get()) - else: - value.setLen(0) - return false + let res = ? pb.getRepeatedField(field, items) + if not(res): + ok(false) + else: + for item in items: + let ma = MultiAddress.init(item) + if ma.isOk(): + value.add(ma.get()) + else: + value.setLen(0) + return err(ProtoError.IncorrectBlob) + ok(true) diff --git a/libp2p/peerid.nim b/libp2p/peerid.nim index 5c81c66..9eeb467 100644 --- a/libp2p/peerid.nim +++ b/libp2p/peerid.nim @@ -219,16 +219,17 @@ proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) = ## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding. write(pb, field, pid.data) -proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool = +proc getField*(pb: ProtoBuffer, field: int, + pid: var PeerID): ProtoResult[bool] {.inline.} = ## Read ``PeerID`` from ProtoBuf's message and validate it var buffer: seq[byte] - var peerId: PeerID - if not(getField(pb, field, buffer)): - return false - if len(buffer) == 0: - return false - if peerId.init(buffer): - pid = peerId - true + let res = ? pb.getField(field, buffer) + if not(res): + ok(false) else: - false + var peerId: PeerID + if peerId.init(buffer): + pid = peerId + ok(true) + else: + err(ProtoError.IncorrectBlob) diff --git a/libp2p/protobuf/minprotobuf.nim b/libp2p/protobuf/minprotobuf.nim index 5a00c47..4476d07 100644 --- a/libp2p/protobuf/minprotobuf.nim +++ b/libp2p/protobuf/minprotobuf.nim @@ -11,7 +11,8 @@ {.push raises: [Defect].} -import ../varint, stew/endians2 +import ../varint, stew/[endians2, results] +export results const MaxMessageSize* = 1'u shl 22 @@ -51,12 +52,15 @@ type of StartGroup, EndGroup: discard - ProtoResult {.pure.} = enum - VarintDecodeError, - MessageIncompleteError, - BufferOverflowError, - MessageSizeTooBigError, - NoError + ProtoError* {.pure.} = enum + VarintDecode, + MessageIncomplete, + BufferOverflow, + MessageTooBig, + BadWireType, + IncorrectBlob + + ProtoResult*[T] = Result[T, ProtoError] ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 | hint | hint32 | hint64 | float32 | float64 @@ -361,7 +365,8 @@ proc finish*(pb: var ProtoBuffer) = else: pb.offset = 0 -proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool = +proc getHeader(data: var ProtoBuffer, + header: var ProtoHeader): ProtoResult[void] = var length = 0 var hdr = 0'u64 if PB.getUVarint(data.toOpenArray(), length, hdr).isOk(): @@ -370,34 +375,34 @@ proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool = if wire in SupportedWireTypes: data.offset += length header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire)) - true + ok() else: - false + err(ProtoError.BadWireType) else: - false + err(ProtoError.VarintDecode) -proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool = +proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] = case header.wire of ProtoFieldKind.Varint: var length = 0 var value = 0'u64 if PB.getUVarint(data.toOpenArray(), length, value).isOk(): data.offset += length - true + ok() else: - false + err(ProtoError.VarintDecode) of ProtoFieldKind.Fixed32: if data.isEnough(sizeof(uint32)): data.offset += sizeof(uint32) - true + ok() else: - false + err(ProtoError.VarintDecode) of ProtoFieldKind.Fixed64: if data.isEnough(sizeof(uint64)): data.offset += sizeof(uint64) - true + ok() else: - false + err(ProtoError.VarintDecode) of ProtoFieldKind.Length: var length = 0 var bsize = 0'u64 @@ -406,19 +411,19 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool = if bsize <= uint64(MaxMessageSize): if data.isEnough(int(bsize)): data.offset += int(bsize) - true + ok() else: - false + err(ProtoError.MessageIncomplete) else: - false + err(ProtoError.MessageTooBig) else: - false + err(ProtoError.VarintDecode) of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup: - false + err(ProtoError.BadWireType) proc getValue[T: ProtoScalar](data: var ProtoBuffer, header: ProtoHeader, - outval: var T): ProtoResult = + outval: var T): ProtoResult[void] = when (T is uint64) or (T is uint32) or (T is uint): doAssert(header.wire == ProtoFieldKind.Varint) var length = 0 @@ -426,9 +431,9 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer, if PB.getUVarint(data.toOpenArray(), length, value).isOk(): data.offset += length outval = value - ProtoResult.NoError + ok() else: - ProtoResult.VarintDecodeError + err(ProtoError.VarintDecode) elif (T is zint64) or (T is zint32) or (T is zint) or (T is hint64) or (T is hint32) or (T is hint): doAssert(header.wire == ProtoFieldKind.Varint) @@ -437,29 +442,29 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer, if getSVarint(data.toOpenArray(), length, value).isOk(): data.offset += length outval = value - ProtoResult.NoError + ok() else: - ProtoResult.VarintDecodeError + err(ProtoError.VarintDecode) elif T is float32: doAssert(header.wire == ProtoFieldKind.Fixed32) if data.isEnough(sizeof(float32)): outval = cast[float32](fromBytesLE(uint32, data.toOpenArray())) data.offset += sizeof(float32) - ProtoResult.NoError + ok() else: - ProtoResult.MessageIncompleteError + err(ProtoError.MessageIncomplete) elif T is float64: doAssert(header.wire == ProtoFieldKind.Fixed64) if data.isEnough(sizeof(float64)): outval = cast[float64](fromBytesLE(uint64, data.toOpenArray())) data.offset += sizeof(float64) - ProtoResult.NoError + ok() else: - ProtoResult.MessageIncompleteError + err(ProtoError.MessageIncomplete) proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, outBytes: var openarray[T], - outLength: var int): ProtoResult = + outLength: var int): ProtoResult[void] = doAssert(header.wire == ProtoFieldKind.Length) var length = 0 var bsize = 0'u64 @@ -474,20 +479,20 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, if bsize > 0'u64: copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) data.offset += int(bsize) - ProtoResult.NoError + ok() else: # Buffer overflow should not be critical failure data.offset += int(bsize) - ProtoResult.BufferOverflowError + err(ProtoError.BufferOverflow) else: - ProtoResult.MessageIncompleteError + err(ProtoError.MessageIncomplete) else: - ProtoResult.MessageSizeTooBigError + err(ProtoError.MessageTooBig) else: - ProtoResult.VarintDecodeError + err(ProtoError.VarintDecode) proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, - outBytes: var T): ProtoResult = + outBytes: var T): ProtoResult[void] = doAssert(header.wire == ProtoFieldKind.Length) var length = 0 var bsize = 0'u64 @@ -501,27 +506,24 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, if bsize > 0'u64: copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) data.offset += int(bsize) - ProtoResult.NoError + ok() else: - ProtoResult.MessageIncompleteError + err(ProtoError.MessageIncomplete) else: - ProtoResult.MessageSizeTooBigError + err(ProtoError.MessageTooBig) else: - ProtoResult.VarintDecodeError + err(ProtoError.VarintDecode) proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int, - output: var T): bool = + output: var T): ProtoResult[bool] = checkFieldNumber(field) - var value: T + var current: T var res = false var pb = data - output = T(0) while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): - output = T(0) - return false + ? pb.getHeader(header) let wireCheck = when (T is uint64) or (T is uint32) or (T is uint) or (T is zint64) or (T is zint32) or (T is zint) or @@ -533,28 +535,29 @@ proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int, header.wire == ProtoFieldKind.Fixed64 if header.index == uint64(field): if wireCheck: - let r = getValue(pb, header, value) - case r - of ProtoResult.NoError: + var value: T + let vres = pb.getValue(header, value) + if vres.isOk(): res = true - output = value + current = value else: - return false + return err(vres.error) else: # We are ignoring wire types different from what we expect, because it # is how `protoc` is working. - if not(skipValue(pb, header)): - output = T(0) - return false + ? pb.skipValue(header) else: - if not(skipValue(pb, header)): - output = T(0) - return false - res + ? pb.skipValue(header) + + if res: + output = current + ok(true) + else: + ok(false) proc getField*[T: byte|char](data: ProtoBuffer, field: int, output: var openarray[T], - outlen: var int): bool = + outlen: var int): ProtoResult[bool] = checkFieldNumber(field) var pb = data var res = false @@ -563,182 +566,191 @@ proc getField*[T: byte|char](data: ProtoBuffer, field: int, while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): + let hres = pb.getHeader(header) + if hres.isErr(): if len(output) > 0: zeroMem(addr output[0], len(output)) outlen = 0 - return false - + return err(hres.error) if header.index == uint64(field): if header.wire == ProtoFieldKind.Length: - let r = getValue(pb, header, output, outlen) - case r - of ProtoResult.NoError: + let vres = pb.getValue(header, output, outlen) + if vres.isOk(): res = true - of ProtoResult.BufferOverflowError: + else: # Buffer overflow error is not critical error, we still can get # field values with proper size. - discard - else: - if len(output) > 0: - zeroMem(addr output[0], len(output)) - return false + if vres.error != ProtoError.BufferOverflow: + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return err(vres.error) else: # We are ignoring wire types different from ProtoFieldKind.Length, # because it is how `protoc` is working. - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): if len(output) > 0: zeroMem(addr output[0], len(output)) outlen = 0 - return false + return err(sres.error) else: - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): if len(output) > 0: zeroMem(addr output[0], len(output)) outlen = 0 - return false + return err(sres.error) - res + if res: + ok(true) + else: + ok(false) proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int, - output: var T): bool = + output: var T): ProtoResult[bool] = checkFieldNumber(field) var res = false var pb = data while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): + let hres = pb.getHeader(header) + if hres.isErr(): output.setLen(0) - return false - + return err(hres.error) if header.index == uint64(field): if header.wire == ProtoFieldKind.Length: - let r = getValue(pb, header, output) - case r - of ProtoResult.NoError: + let vres = pb.getValue(header, output) + if vres.isOk(): res = true - of ProtoResult.BufferOverflowError: - # Buffer overflow error is not critical error, we still can get - # field values with proper size. - discard else: output.setLen(0) - return false + return err(vres.error) else: # We are ignoring wire types different from ProtoFieldKind.Length, # because it is how `protoc` is working. - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) else: - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): output.setLen(0) - return false - - res - -proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {. - inline.} = - var buffer: seq[byte] - if pb.getField(field, buffer): - output = initProtoBuffer(buffer) - true + return err(sres.error) + if res: + ok(true) else: - false + ok(false) + +proc getField*(pb: ProtoBuffer, field: int, + output: var ProtoBuffer): ProtoResult[bool] {.inline.} = + var buffer: seq[byte] + let res = pb.getField(field, buffer) + if res.isOk(): + if res.get(): + output = initProtoBuffer(buffer) + ok(true) + else: + ok(false) + else: + err(res.error) proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int, - output: var seq[T]): bool = + output: var seq[T]): ProtoResult[bool] = checkFieldNumber(field) var pb = data output.setLen(0) while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): + let hres = pb.getHeader(header) + if hres.isErr(): output.setLen(0) - return false - + return err(hres.error) if header.index == uint64(field): if header.wire == ProtoFieldKind.Length: var item: T - let r = getValue(pb, header, item) - case r - of ProtoResult.NoError: + let vres = pb.getValue(header, item) + if vres.isOk(): output.add(item) else: output.setLen(0) - return false + return err(vres.error) else: - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) else: - if not(skipValue(pb, header)): + let sres = pb.skipValue(header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) if len(output) > 0: - true + ok(true) else: - false + ok(false) -proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer, - field: int, - output: var seq[T]): bool = +proc getRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, + output: var seq[T]): ProtoResult[bool] = checkFieldNumber(field) var pb = data output.setLen(0) while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): + let hres = pb.getHeader(header) + if hres.isErr(): output.setLen(0) - return false + return err(hres.error) if header.index == uint64(field): if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32, ProtoFieldKind.Fixed64}: var item: T - let r = getValue(pb, header, item) - case r - of ProtoResult.NoError: + let vres = getValue(pb, header, item) + if vres.isOk(): output.add(item) else: output.setLen(0) - return false + return err(vres.error) else: - if not(skipValue(pb, header)): + let sres = skipValue(pb, header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) else: - if not(skipValue(pb, header)): + let sres = skipValue(pb, header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) if len(output) > 0: - true + ok(true) else: - false + ok(false) proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, - output: var seq[T]): bool = + output: var seq[T]): ProtoResult[bool] = checkFieldNumber(field) var pb = data output.setLen(0) while not(pb.isEmpty()): var header: ProtoHeader - if not(pb.getHeader(header)): + let hres = pb.getHeader(header) + if hres.isErr(): output.setLen(0) - return false + return err(hres.error) if header.index == uint64(field): if header.wire == ProtoFieldKind.Length: var arritem: seq[byte] - let rarr = getValue(pb, header, arritem) - case rarr - of ProtoResult.NoError: + let ares = getValue(pb, header, arritem) + if ares.isOk(): var pbarr = initProtoBuffer(arritem) let itemHeader = when (T is uint64) or (T is uint32) or (T is uint) or @@ -751,29 +763,30 @@ proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, ProtoHeader(wire: ProtoFieldKind.Fixed64) while not(pbarr.isEmpty()): var item: T - let res = getValue(pbarr, itemHeader, item) - case res - of ProtoResult.NoError: + let vres = getValue(pbarr, itemHeader, item) + if vres.isOk(): output.add(item) else: output.setLen(0) - return false + return err(vres.error) else: output.setLen(0) - return false + return err(ares.error) else: - if not(skipValue(pb, header)): + let sres = skipValue(pb, header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) else: - if not(skipValue(pb, header)): + let sres = skipValue(pb, header) + if sres.isErr(): output.setLen(0) - return false + return err(sres.error) if len(output) > 0: - true + ok(true) else: - false + ok(false) proc getVarintValue*(data: var ProtoBuffer, field: int, value: var SomeVarint): int {.deprecated.} = diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index a998ea9..4de5d49 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -46,52 +46,56 @@ type proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = result = initProtoBuffer() - result.write(1, peerInfo.publicKey.get().getBytes().tryGet()) - for ma in peerInfo.addrs: result.write(2, ma.data.buffer) - for proto in peerInfo.protocols: result.write(3, proto) - result.write(4, observedAddr.data.buffer) - let protoVersion = ProtoVersion result.write(5, protoVersion) - let agentVersion = AgentVersion result.write(6, agentVersion) result.finish() -proc decodeMsg*(buf: seq[byte]): IdentifyInfo = +proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] = + var + iinfo: IdentifyInfo + pubKey: PublicKey + oaddr: MultiAddress + protoVersion: string + agentVersion: string + var pb = initProtoBuffer(buf) - var pubKey: PublicKey - if pb.getField(1, pubKey): - trace "read public key from message", pubKey = ($pubKey).shortLog - result.pubKey = some(pubKey) + let r1 = pb.getField(1, pubKey) + let r2 = pb.getRepeatedField(2, iinfo.addrs) + let r3 = pb.getRepeatedField(3, iinfo.protos) + let r4 = pb.getField(4, oaddr) + let r5 = pb.getField(5, protoVersion) + let r6 = pb.getField(6, agentVersion) - if pb.getRepeatedField(2, result.addrs): - trace "read addresses from message", addresses = result.addrs + let res = r1.isOk() and r2.isOk() and r3.isOk() and + r4.isOk() and r5.isOk() and r6.isOk() - if pb.getRepeatedField(3, result.protos): - trace "read protos from message", protocols = result.protos - - var observableAddr: MultiAddress - if pb.getField(4, observableAddr): - trace "read observableAddr from message", address = observableAddr - result.observedAddr = some(observableAddr) - - var protoVersion = "" - if pb.getField(5, protoVersion): - trace "read protoVersion from message", protoVersion = protoVersion - result.protoVersion = some(protoVersion) - - var agentVersion = "" - if pb.getField(6, agentVersion): - trace "read agentVersion from message", agentVersion = agentVersion - result.agentVersion = some(agentVersion) + if res: + if r1.get(): + iinfo.pubKey = some(pubKey) + if r4.get(): + iinfo.observedAddr = some(oaddr) + if r5.get(): + iinfo.protoVersion = some(protoVersion) + if r6.get(): + iinfo.agentVersion = some(agentVersion) + trace "decodeMsg: decoded message", pubkey = ($pubKey).shortLog, + addresses = $iinfo.addrs, protocols = $iinfo.protos, + observable_address = $iinfo.observedAddr, + proto_version = $iinfo.protoVersion, + agent_version = $iinfo.agentVersion + some(iinfo) + else: + trace "decodeMsg: failed to decode received message" + none[IdentifyInfo]() proc newIdentify*(peerInfo: PeerInfo): Identify = new result @@ -122,11 +126,13 @@ proc identify*(p: Identify, trace "initiating identify", peer = $conn var message = await conn.readLp(64*1024) if len(message) == 0: - trace "identify: Invalid or empty message received!" - raise newException(IdentityInvalidMsgError, - "Invalid or empty message received!") + trace "identify: Empty message received!" + raise newException(IdentityInvalidMsgError, "Empty message received!") - result = decodeMsg(message) + let infoOpt = decodeMsg(message) + if infoOpt.isNone(): + raise newException(IdentityInvalidMsgError, "Incorrect message received!") + result = infoOpt.get() if not isNil(remotePeerInfo) and result.pubKey.isSome: let peer = PeerID.init(result.pubKey.get()) diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index ca72a32..612a5eb 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -43,7 +43,7 @@ type RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} -func hash*(p: PubSubPeer): Hash = +func hash*(p: PubSubPeer): Hash = # int is either 32/64, so intptr basically, pubsubpeer is a ref cast[pointer](p).hash @@ -114,7 +114,13 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} = trace "message already received, skipping", peer = p.id continue - var msg = decodeRpcMsg(data) + var rmsg = decodeRpcMsg(data) + if rmsg.isErr(): + notice "failed to decode msg from peer", peer = p.id + break + + var msg = rmsg.get() + trace "decoded msg from peer", peer = p.id, msg = msg.shortLog # trigger hooks p.recvObservers(msg) @@ -149,11 +155,11 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = p.sendObservers(mm) let encoded = encodeRpcMsg(mm) - if encoded.buffer.len <= 0: + if encoded.len <= 0: trace "empty message, skipping", peer = p.id return - let digest = $(sha256.digest(encoded.buffer)) + let digest = $(sha256.digest(encoded)) if digest in p.sentRpcCache: trace "message already sent to peer, skipping", peer = p.id libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id]) @@ -164,8 +170,8 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = encoded = digest if p.connected: # this can happen if the remote disconnected trace "sending encoded msgs to peer", peer = p.id, - encoded = encoded.buffer.shortLog - await p.sendConn.writeLp(encoded.buffer) + encoded = encoded.shortLog + await p.sendConn.writeLp(encoded) p.sentRpcCache.put(digest) for m in msgs: diff --git a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim index c5a3eb3..4958320 100644 --- a/libp2p/protocols/pubsub/rpc/protobuf.nim +++ b/libp2p/protocols/pubsub/rpc/protobuf.nim @@ -80,163 +80,151 @@ proc encodeMessage*(msg: Message): seq[byte] = proc write*(pb: var ProtoBuffer, field: int, msg: Message) = pb.write(field, encodeMessage(msg)) -proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} = +proc decodeGraft*(pb: ProtoBuffer): ProtoResult[ControlGraft] {. + inline.} = trace "decodeGraft: decoding message" var control = ControlGraft() - var topicId: string - if pb.getField(1, topicId): - control.topicId = topicId - trace "decodeGraft: read topicId", topic_id = topicId + if ? pb.getField(1, control.topicId): + trace "decodeGraft: read topicId", topic_id = control.topicId else: trace "decodeGraft: topicId is missing" - control + ok(control) -proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} = +proc decodePrune*(pb: ProtoBuffer): ProtoResult[ControlPrune] {. + inline.} = trace "decodePrune: decoding message" var control = ControlPrune() - var topicId: string - if pb.getField(1, topicId): - control.topicId = topicId - trace "decodePrune: read topicId", topic_id = topicId + if ? pb.getField(1, control.topicId): + trace "decodePrune: read topicId", topic_id = control.topicId else: trace "decodePrune: topicId is missing" - control + ok(control) -proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} = +proc decodeIHave*(pb: ProtoBuffer): ProtoResult[ControlIHave] {. + inline.} = trace "decodeIHave: decoding message" var control = ControlIHave() - var topicId: string - if pb.getField(1, topicId): - control.topicId = topicId - trace "decodeIHave: read topicId", topic_id = topicId + if ? pb.getField(1, control.topicId): + trace "decodeIHave: read topicId", topic_id = control.topicId else: trace "decodeIHave: topicId is missing" - if pb.getRepeatedField(2, control.messageIDs): + if ? pb.getRepeatedField(2, control.messageIDs): trace "decodeIHave: read messageIDs", message_ids = control.messageIDs else: trace "decodeIHave: no messageIDs" - control + ok(control) -proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} = +proc decodeIWant*(pb: ProtoBuffer): ProtoResult[ControlIWant] {.inline.} = trace "decodeIWant: decoding message" var control = ControlIWant() - if pb.getRepeatedField(1, control.messageIDs): + if ? pb.getRepeatedField(1, control.messageIDs): trace "decodeIWant: read messageIDs", message_ids = control.messageIDs else: trace "decodeIWant: no messageIDs" + ok(control) -proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} = +proc decodeControl*(pb: ProtoBuffer): ProtoResult[Option[ControlMessage]] {. + inline.} = trace "decodeControl: decoding message" var buffer: seq[byte] - if pb.getField(3, buffer): + if ? pb.getField(3, buffer): var control: ControlMessage var cpb = initProtoBuffer(buffer) var ihavepbs: seq[seq[byte]] var iwantpbs: seq[seq[byte]] var graftpbs: seq[seq[byte]] var prunepbs: seq[seq[byte]] - - discard cpb.getRepeatedField(1, ihavepbs) - discard cpb.getRepeatedField(2, iwantpbs) - discard cpb.getRepeatedField(3, graftpbs) - discard cpb.getRepeatedField(4, prunepbs) - - for item in ihavepbs: - control.ihave.add(decodeIHave(initProtoBuffer(item))) - for item in iwantpbs: - control.iwant.add(decodeIWant(initProtoBuffer(item))) - for item in graftpbs: - control.graft.add(decodeGraft(initProtoBuffer(item))) - for item in prunepbs: - control.prune.add(decodePrune(initProtoBuffer(item))) - - trace "decodeControl: " - some(control) + if ? cpb.getRepeatedField(1, ihavepbs): + for item in ihavepbs: + control.ihave.add(? decodeIHave(initProtoBuffer(item))) + if ? cpb.getRepeatedField(2, iwantpbs): + for item in iwantpbs: + control.iwant.add(? decodeIWant(initProtoBuffer(item))) + if ? cpb.getRepeatedField(3, graftpbs): + for item in graftpbs: + control.graft.add(? decodeGraft(initProtoBuffer(item))) + if ? cpb.getRepeatedField(4, prunepbs): + for item in prunepbs: + control.prune.add(? decodePrune(initProtoBuffer(item))) + trace "decodeControl: message statistics", graft_count = len(control.graft), + prune_count = len(control.prune), + ihave_count = len(control.ihave), + iwant_count = len(control.iwant) + ok(some(control)) else: - none[ControlMessage]() + ok(none[ControlMessage]()) -proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} = +proc decodeSubscription*(pb: ProtoBuffer): ProtoResult[SubOpts] {.inline.} = trace "decodeSubscription: decoding message" var subflag: uint64 var sub = SubOpts() - if pb.getField(1, subflag): + if ? pb.getField(1, subflag): sub.subscribe = bool(subflag) trace "decodeSubscription: read subscribe", subscribe = subflag else: trace "decodeSubscription: subscribe is missing" - if pb.getField(2, sub.topic): + if ? pb.getField(2, sub.topic): trace "decodeSubscription: read topic", topic = sub.topic else: trace "decodeSubscription: topic is missing" + ok(sub) - sub - -proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} = +proc decodeSubscriptions*(pb: ProtoBuffer): ProtoResult[seq[SubOpts]] {. + inline.} = trace "decodeSubscriptions: decoding message" var subpbs: seq[seq[byte]] var subs: seq[SubOpts] - if pb.getRepeatedField(1, subpbs): + let res = ? pb.getRepeatedField(1, subpbs) + if res: trace "decodeSubscriptions: read subscriptions", count = len(subpbs) for item in subpbs: - let sub = decodeSubscription(initProtoBuffer(item)) - subs.add(sub) + subs.add(? decodeSubscription(initProtoBuffer(item))) + if len(subs) == 0: + trace "decodeSubscription: no subscriptions found" + ok(subs) - if len(subs) == 0: - trace "decodeSubscription: no subscriptions found" - - subs - -proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} = +proc decodeMessage*(pb: ProtoBuffer): ProtoResult[Message] {.inline.} = trace "decodeMessage: decoding message" var msg: Message - if pb.getField(1, msg.fromPeer): + if ? pb.getField(1, msg.fromPeer): trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty() else: trace "decodeMessage: fromPeer is missing" - - if pb.getField(2, msg.data): + if ? pb.getField(2, msg.data): trace "decodeMessage: read data", data = msg.data.shortLog() else: trace "decodeMessage: data is missing" - - if pb.getField(3, msg.seqno): + if ? pb.getField(3, msg.seqno): trace "decodeMessage: read seqno", seqno = msg.data.shortLog() else: trace "decodeMessage: seqno is missing" - - if pb.getRepeatedField(4, msg.topicIDs): + if ? pb.getRepeatedField(4, msg.topicIDs): trace "decodeMessage: read topics", topic_ids = msg.topicIDs else: trace "decodeMessage: topics are missing" - - if pb.getField(5, msg.signature): + if ? pb.getField(5, msg.signature): trace "decodeMessage: read signature", signature = msg.signature.shortLog() else: trace "decodeMessage: signature is missing" - - if pb.getField(6, msg.key): + if ? pb.getField(6, msg.key): trace "decodeMessage: read public key", key = msg.key.shortLog() else: trace "decodeMessage: public key is missing" + ok(msg) - msg - -proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} = +proc decodeMessages*(pb: ProtoBuffer): ProtoResult[seq[Message]] {.inline.} = trace "decodeMessages: decoding message" var msgpbs: seq[seq[byte]] var msgs: seq[Message] - if pb.getRepeatedField(2, msgpbs): + if ? pb.getRepeatedField(2, msgpbs): trace "decodeMessages: read messages", count = len(msgpbs) for item in msgpbs: - let msg = decodeMessage(initProtoBuffer(item)) - msgs.add(msg) - - if len(msgs) == 0: + msgs.add(? decodeMessage(initProtoBuffer(item))) + else: trace "decodeMessages: no messages found" + ok(msgs) - msgs - -proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer = +proc encodeRpcMsg*(msg: RPCMsg): seq[byte] = trace "encodeRpcMsg: encoding message", msg = msg.shortLog() var pb = initProtoBuffer() for item in msg.subscriptions: @@ -247,14 +235,13 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer = pb.write(3, msg.control.get()) if len(pb.buffer) > 0: pb.finish() - result = pb + pb.buffer -proc decodeRpcMsg*(msg: seq[byte]): RPCMsg = +proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} = trace "decodeRpcMsg: decoding message", msg = msg.shortLog() var pb = initProtoBuffer(msg) var rpcMsg: RPCMsg - rpcMsg.messages = pb.decodeMessages() - rpcMsg.subscriptions = pb.decodeSubscriptions() - rpcMsg.control = pb.decodeControl() - - rpcMsg + rpcMsg.messages = ? pb.decodeMessages() + rpcMsg.subscriptions = ? pb.decodeSubscriptions() + rpcMsg.control = ? pb.decodeControl() + ok(rpcMsg) diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index d5398cc..c34be29 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -449,9 +449,11 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon remoteSig: Signature remoteSigBytes: seq[byte] - if not(remoteProof.getField(1, remotePubKeyBytes)): + let r1 = remoteProof.getField(1, remotePubKeyBytes) + let r2 = remoteProof.getField(2, remoteSigBytes) + if r1.isErr() or not(r1.get()): raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if not(remoteProof.getField(2, remoteSigBytes)): + if r2.isErr() or not(r2.get()): raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") if not remotePubKey.init(remotePubKeyBytes): diff --git a/tests/testminprotobuf.nim b/tests/testminprotobuf.nim index a4fe7fe..d167285 100644 --- a/tests/testminprotobuf.nim +++ b/tests/testminprotobuf.nim @@ -88,7 +88,7 @@ suite "MinProtobuf test suite": var value: uint64 var pb = initProtoBuffer(data) let res = pb.getField(1, value) - doAssert(res) + doAssert(res.isOk() == true and res.get() == true) value proc getFixed32EncodedValue(value: float32): seq[byte] = @@ -101,7 +101,7 @@ suite "MinProtobuf test suite": var value: float32 var pb = initProtoBuffer(data) let res = pb.getField(1, value) - doAssert(res) + doAssert(res.isOk() == true and res.get() == true) cast[uint32](value) proc getFixed64EncodedValue(value: float64): seq[byte] = @@ -114,7 +114,7 @@ suite "MinProtobuf test suite": var value: float64 var pb = initProtoBuffer(data) let res = pb.getField(1, value) - doAssert(res) + doAssert(res.isOk() == true and res.get() == true) cast[uint64](value) proc getLengthEncodedValue(value: string): seq[byte] = @@ -134,8 +134,7 @@ suite "MinProtobuf test suite": var valueLen = 0 var pb = initProtoBuffer(data) let res = pb.getField(1, value, valueLen) - - doAssert(res) + doAssert(res.isOk() == true and res.get() == true) value.setLen(valueLen) value @@ -173,17 +172,19 @@ suite "MinProtobuf test suite": # corrupting data.setLen(len(data) - 1) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[varint] non-existent field test": for i in 0 ..< len(VarintValues): var value: uint64 var data = getVarintEncodedValue(VarintValues[i]) var pb = initProtoBuffer(data) + let res = pb.getField(2, value) check: - pb.getField(2, value) == false - value == 0'u64 + res.isOk() == true + res.get() == false test "[varint] corrupted header test": for i in 0 ..< len(VarintValues): @@ -192,15 +193,17 @@ suite "MinProtobuf test suite": var data = getVarintEncodedValue(VarintValues[i]) data.corruptHeader(k) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[varint] empty buffer test": var value: uint64 var pb = initProtoBuffer() + let res = pb.getField(1, value) check: - pb.getField(1, value) == false - value == 0'u64 + res.isOk() == true + res.get() == false test "[varint] Repeated field test": var pb1 = initProtoBuffer() @@ -218,9 +221,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getRepeatedField(2, fieldarr2) let r3 = pb2.getRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 1 len(fieldarr1) == 4 @@ -246,9 +252,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r3 = pb2.getPackedRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 2 len(fieldarr1) == 6 @@ -284,17 +293,19 @@ suite "MinProtobuf test suite": # corrupting data.setLen(len(data) - 1) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[fixed32] non-existent field test": for i in 0 ..< len(Fixed32Values): var value: float32 var data = getFixed32EncodedValue(float32(Fixed32Values[i])) var pb = initProtoBuffer(data) + let res = pb.getField(2, value) check: - pb.getField(2, value) == false - value == float32(0) + res.isOk() == true + res.get() == false test "[fixed32] corrupted header test": for i in 0 ..< len(Fixed32Values): @@ -303,15 +314,17 @@ suite "MinProtobuf test suite": var data = getFixed32EncodedValue(float32(Fixed32Values[i])) data.corruptHeader(k) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[fixed32] empty buffer test": var value: float32 var pb = initProtoBuffer() + let res = pb.getField(1, value) check: - pb.getField(1, value) == false - value == float32(0) + res.isOk() == true + res.get() == false test "[fixed32] Repeated field test": var pb1 = initProtoBuffer() @@ -329,9 +342,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getRepeatedField(2, fieldarr2) let r3 = pb2.getRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 1 len(fieldarr1) == 4 @@ -360,9 +376,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r3 = pb2.getPackedRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 2 len(fieldarr1) == 5 @@ -397,17 +416,19 @@ suite "MinProtobuf test suite": # corrupting data.setLen(len(data) - 1) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[fixed64] non-existent field test": for i in 0 ..< len(Fixed64Values): var value: float64 var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) var pb = initProtoBuffer(data) + let res = pb.getField(2, value) check: - pb.getField(2, value) == false - value == float64(0) + res.isOk() == true + res.get() == false test "[fixed64] corrupted header test": for i in 0 ..< len(Fixed64Values): @@ -416,15 +437,17 @@ suite "MinProtobuf test suite": var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) data.corruptHeader(k) var pb = initProtoBuffer(data) + let res = pb.getField(1, value) check: - pb.getField(1, value) == false + res.isErr() == true test "[fixed64] empty buffer test": var value: float64 var pb = initProtoBuffer() + let res = pb.getField(1, value) check: - pb.getField(1, value) == false - value == float64(0) + res.isOk() == true + res.get() == false test "[fixed64] Repeated field test": var pb1 = initProtoBuffer() @@ -442,9 +465,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getRepeatedField(2, fieldarr2) let r3 = pb2.getRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 1 len(fieldarr1) == 4 @@ -474,9 +500,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r3 = pb2.getPackedRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 2 len(fieldarr1) == 8 @@ -523,8 +552,9 @@ suite "MinProtobuf test suite": # corrupting data.setLen(len(data) - 1) var pb = initProtoBuffer(data) + let res = pb.getField(1, value, valueLen) check: - pb.getField(1, value, valueLen) == false + res.isErr() == true test "[length] non-existent field test": for i in 0 ..< len(LengthValues): @@ -532,8 +562,10 @@ suite "MinProtobuf test suite": var valueLen = 0 var data = getLengthEncodedValue(LengthValues[i]) var pb = initProtoBuffer(data) + let res = pb.getField(2, value, valueLen) check: - pb.getField(2, value, valueLen) == false + res.isOk() == true + res.get() == false valueLen == 0 test "[length] corrupted header test": @@ -544,15 +576,18 @@ suite "MinProtobuf test suite": var data = getLengthEncodedValue(LengthValues[i]) data.corruptHeader(k) var pb = initProtoBuffer(data) + let res = pb.getField(1, value, valueLen) check: - pb.getField(1, value, valueLen) == false + res.isErr() == true test "[length] empty buffer test": var value = newSeq[byte](len(LengthValues[0])) var valueLen = 0 var pb = initProtoBuffer() + let res = pb.getField(1, value, valueLen) check: - pb.getField(1, value, valueLen) == false + res.isOk() == true + res.get() == false valueLen == 0 test "[length] buffer overflow test": @@ -562,8 +597,10 @@ suite "MinProtobuf test suite": var value = newString(len(LengthValues[i]) - 1) var valueLen = 0 var pb = initProtoBuffer(data) + let res = pb.getField(1, value, valueLen) check: - pb.getField(1, value, valueLen) == false + res.isOk() == true + res.get() == false valueLen == len(LengthValues[i]) isFullZero(value) == true @@ -578,8 +615,10 @@ suite "MinProtobuf test suite": var pb2 = initProtoBuffer(pb1.buffer) var value = newString(4) var valueLen = 0 + let res = pb2.getField(1, value, valueLen) check: - pb2.getField(1, value, valueLen) == true + res.isOk() == true + res.get() == true value == "SOME" test "[length] too big message test": @@ -593,8 +632,9 @@ suite "MinProtobuf test suite": var pb2 = initProtoBuffer(pb1.buffer) var value = newString(MaxMessageSize + 1) var valueLen = 0 + let res = pb2.getField(1, value, valueLen) check: - pb2.getField(1, value, valueLen) == false + res.isErr() == true test "[length] Repeated field test": var pb1 = initProtoBuffer() @@ -612,9 +652,12 @@ suite "MinProtobuf test suite": let r2 = pb2.getRepeatedField(2, fieldarr2) let r3 = pb2.getRepeatedField(3, fieldarr3) check: - r1 == true - r2 == true - r3 == false + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r1.get() == true + r2.get() == true + r3.get() == false len(fieldarr3) == 0 len(fieldarr2) == 1 len(fieldarr1) == 4 @@ -662,11 +705,16 @@ suite "MinProtobuf test suite": var lengthValue = newString(10) var lengthSize: int + let r1 = pb.getField(1, varintValue) + let r2 = pb.getField(2, fixed32Value) + let r3 = pb.getField(3, fixed64Value) + let r4 = pb.getField(4, lengthValue, lengthSize) + check: - pb.getField(1, varintValue) == true - pb.getField(2, fixed32Value) == true - pb.getField(3, fixed64Value) == true - pb.getField(4, lengthValue, lengthSize) == true + r1.isOk() == true + r2.isOk() == true + r3.isOk() == true + r4.isOk() == true lengthValue.setLen(lengthSize)