diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 78b7dd519..893a2245f 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] = ## ## Returns number of bytes (octets) needed to store private key ``key``. var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() var blen = len(msg.buffer) if len(data) >= blen: @@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] = ## ## Returns number of bytes (octets) needed to store public key ``key``. var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() var blen = len(msg.buffer) if len(data) >= blen and blen > 0: @@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] = ## Return private key ``key`` in binary form (using libp2p's protobuf ## serialization). var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() ok(msg.buffer) @@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] = ## Return public key ``key`` in binary form (using libp2p's protobuf ## serialization). var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() ok(msg.buffer) @@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool = var buffer: seq[byte] if len(data) > 0: var pb = initProtoBuffer(@data) - if pb.getVarintValue(1, id) != 0: - if pb.getBytes(2, buffer) != 0: - if cast[int8](id) in SupportedSchemesInt: - var scheme = cast[PKScheme](cast[int8](id)) - when key is PrivateKey: - var nkey = PrivateKey(scheme: scheme) - else: - var nkey = PublicKey(scheme: scheme) - case scheme: - of PKScheme.RSA: - if init(nkey.rsakey, buffer).isOk: - key = nkey - return true - of PKScheme.Ed25519: - if init(nkey.edkey, buffer): - key = nkey - return true - of PKScheme.ECDSA: - if init(nkey.eckey, buffer).isOk: - key = nkey - return true - of PKScheme.Secp256k1: - if init(nkey.skkey, buffer).isOk: - key = nkey - return true - else: - return false + if pb.getField(1, id) and pb.getField(2, buffer): + if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0: + var scheme = cast[PKScheme](cast[int8](id)) + when key is PrivateKey: + var nkey = PrivateKey(scheme: scheme) + else: + var nkey = PublicKey(scheme: scheme) + case scheme: + of PKScheme.RSA: + if init(nkey.rsakey, buffer).isOk: + key = nkey + return true + of PKScheme.Ed25519: + if init(nkey.edkey, buffer): + key = nkey + return true + of PKScheme.ECDSA: + if init(nkey.eckey, buffer).isOk: + key = nkey + return true + of PKScheme.Secp256k1: + if init(nkey.skkey, buffer).isOk: + key = nkey + return true + else: + return false proc init*(sig: var Signature, data: openarray[byte]): bool = ## Initialize signature ``sig`` from raw binary form. @@ -727,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte], ## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and ## comma-delimeted list of supported hashes ``hashes``. var msg = initProtoBuffer({WithUint32BeLength}) - msg.write(initProtoField(1, nonce)) - msg.write(initProtoField(2, pubkey)) - msg.write(initProtoField(3, exchanges)) - msg.write(initProtoField(4, ciphers)) - msg.write(initProtoField(5, hashes)) + msg.write(1, nonce) + msg.write(2, pubkey) + msg.write(3, exchanges) + msg.write(4, ciphers) + msg.write(5, hashes) msg.finish() shallowCopy(result, msg.buffer) @@ -744,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. var pb = initProtoBuffer(message) - if pb.getLengthValue(1, nonce) != -1 and - pb.getLengthValue(2, pubkey) != -1 and - pb.getLengthValue(3, exchanges) != -1 and - pb.getLengthValue(4, ciphers) != -1 and - pb.getLengthValue(5, hashes) != -1: - result = true + pb.getField(1, nonce) and pb.getField(2, pubkey) and + pb.getField(3, exchanges) and pb.getField(4, ciphers) and + pb.getField(5, hashes) proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = ## Create SecIO exchange message using ephemeral public key ``epubkey`` and ## signature of proposal blocks ``signature``. var msg = initProtoBuffer({WithUint32BeLength}) - msg.write(initProtoField(1, epubkey)) - msg.write(initProtoField(2, signature)) + msg.write(1, epubkey) + msg.write(2, signature) msg.finish() shallowCopy(result, msg.buffer) @@ -767,9 +763,7 @@ proc decodeExchange*(message: seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. var pb = initProtoBuffer(message) - if pb.getLengthValue(1, pubkey) != -1 and - pb.getLengthValue(2, signature) != -1: - result = true + pb.getField(1, pubkey) and pb.getField(2, signature) ## Serialization/Deserialization helpers @@ -788,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {. ## Write Signature value ``sig`` to buffer ``vb``. vb.writeSeq(sig.getBytes().tryGet()) -proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {. - raises: [Defect, ResultError[CryptoError]].} = - ## Initialize ProtoField with PublicKey ``pubkey``. - result = initProtoField(index, pubkey.getBytes().tryGet()) +proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int, + key: T) {. + inline, raises: [Defect, ResultError[CryptoError]].} = + write(pb, field, key.getBytes().tryGet()) -proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {. - raises: [Defect, ResultError[CryptoError]].} = - ## Initialize ProtoField with PrivateKey ``seckey``. - result = initProtoField(index, seckey.getBytes().tryGet()) +proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {. + inline, raises: [Defect, ResultError[CryptoError]].} = + write(pb, field, sig.getBytes()) -proc initProtoField*(index: int, sig: Signature): ProtoField = +proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {. + deprecated, raises: [Defect, ResultError[CryptoError]].} = + ## Initialize ProtoField with PublicKey/PrivateKey ``key``. + result = initProtoField(index, key.getBytes().tryGet()) + +proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} = ## Initialize ProtoField with Signature ``sig``. result = initProtoField(index, sig.getBytes()) -proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int = - ## Read ``PublicKey`` from ProtoBuf's message and validate it. +proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int, + value: var T): int {.deprecated.} = + ## Read PublicKey/PrivateKey from ProtoBuf's message and validate it. var buf: seq[byte] var key: PublicKey result = getLengthValue(data, field, buf) @@ -813,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int = else: value = key -proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int = - ## Read ``PrivateKey`` from ProtoBuf's message and validate it. - var buf: seq[byte] - var key: PrivateKey - result = getLengthValue(data, field, buf) - if result > 0: - if not key.init(buf): - result = -1 - else: - value = key - -proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int = +proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {. + deprecated.} = ## Read ``Signature`` from ProtoBuf's message and validate it. var buf: seq[byte] var sig: Signature @@ -834,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int = result = -1 else: value = sig + +proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int, + value: var T): bool = + var buffer: seq[byte] + var key: T + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if key.init(buffer): + value = key + true + else: + false + +proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool = + var buffer: seq[byte] + var sig: Signature + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if sig.init(buffer): + value = sig + true + else: + false diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index c6064f075..3d6d85eb6 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -14,9 +14,10 @@ import nativesockets import tables, strutils, stew/shims/net import chronos -import multicodec, multihash, multibase, transcoder, vbuffer, peerid +import multicodec, multihash, multibase, transcoder, vbuffer, peerid, + protobuf/minprotobuf import stew/[base58, base32, endians2, results] -export results +export results, minprotobuf, vbuffer type MAKind* = enum @@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] = else: ok($(proto.mcodec)) -proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] = +proc protoArgument*(ma: MultiAddress, + value: var openarray[byte]): MaResult[int] = ## Returns MultiAddress ``ma`` protocol argument value. ## ## If current MultiAddress do not have argument value, then result will be @@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] var res: int if proto.kind == Fixed: res = proto.size - if len(value) >= res and - vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size: + if len(value) >= res and + vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size: err("multiaddress: Decoding protocol error") else: ok(res) @@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] = let proto = CodeAddresses.getOrDefault(MultiCodec(header)) if proto.kind == None: - yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'") + yield err(MaResult[MultiAddress], "Unsupported protocol '" & + $header & "'") elif proto.kind == Fixed: data.setLen(proto.size) @@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} = return ok(true) ok(false) -proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} = +proc `[]`*(ma: MultiAddress, + codec: MultiCodec): MaResult[MultiAddress] {.inline.} = ## Returns partial MultiAddress with MultiCodec ``codec`` and present in ## MultiAddress ``ma``. for item in ma.items: @@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] = return err("multiaddress: Unsupported protocol '" & $header & "'") if proto.kind in {Fixed, Length, Path}: if isNil(proto.coder.bufferToString): - return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder") + return err("multiaddress: Missing protocol '" & $(proto.mcodec) & + "' coder") if not proto.coder.bufferToString(vb.data, part): return err("multiaddress: Decoding protocol error") parts.add($(proto.mcodec)) @@ -729,12 +734,14 @@ proc init*( of None: raiseAssert "None checked above" -proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} = +proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, + value: PeerID): MaResult[MultiAddress] {.inline.} = ## Initialize MultiAddress object from protocol id ``protocol`` and peer id ## ``value``. init(mtype, protocol, value.data) -proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, + value: int): MaResult[MultiAddress] = ## Initialize MultiAddress object from protocol id ``protocol`` and integer ## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``, ## ``dccp`` and ``sctp`` MultiAddresses. @@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} = if mc != InvalidMultiCodec: result = CodeAddresses.getOrDefault(mc) -proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], + value: string): MaResult[MultiAddress] = ## Initialize MultiAddress object from string representation ``value``. var parts = value.trimRight('/').split('/') if len(parts[0]) != 0: @@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] else: if proto.kind in {Fixed, Length, Path}: if isNil(proto.coder.stringToBuffer): - return err("multiaddress: Missing protocol '" & part & "' transcoder") + return err("multiaddress: Missing protocol '" & + part & "' transcoder") if offset + 1 >= len(parts): return err("multiaddress: Missing protocol '" & part & "' argument") @@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] res.data.write(proto.mcodec) let res = proto.coder.stringToBuffer(parts[offset + 1], res.data) if not res: - return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`") + return err("multiaddress: Error encoding `" & part & "/" & + parts[offset + 1] & "`") offset += 2 elif proto.kind == Path: var path = "/" & (parts[(offset + 1)..^1].join("/")) res.data.write(proto.mcodec) if not proto.coder.stringToBuffer(path, res.data): - return err("multiaddress: Error encoding `" & part & "/" & path & "`") + return err("multiaddress: Error encoding `" & part & "/" & + path & "`") break elif proto.kind == Marker: @@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] res.data.finish() ok(res) - -proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], + data: openarray[byte]): MaResult[MultiAddress] = ## Initialize MultiAddress with array of bytes ``data``. if len(data) == 0: err("multiaddress: Address could not be empty!") @@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress], var data = initVBuffer() data.write(familyProto.mcodec) var written = familyProto.coder.stringToBuffer($address, data) - doAssert written, "Merely writing a string to a buffer should always be possible" + doAssert written, + "Merely writing a string to a buffer should always be possible" data.write(protoProto.mcodec) written = protoProto.coder.stringToBuffer($port, data) - doAssert written, "Merely writing a string to a buffer should always be possible" + doAssert written, + "Merely writing a string to a buffer should always be possible" data.finish() MultiAddress(data: data) @@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] = else: ok() -proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} = +proc `&`*(m1, m2: MultiAddress): MultiAddress {. + raises: [Defect, ResultError[string]].} = ## Concatenates two addresses ``m1`` and ``m2``, and returns result. ## ## This procedure performs validation of concatenated result and can raise ## exception on error. concat(m1, m2).tryGet() -proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} = +proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {. + raises: [Defect, ResultError[string]].} = ## Concatenates two addresses ``m1`` and ``m2``. ## ## This procedure performs validation of concatenated result and can raise @@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string = result = "(" & sub.join("|") & ")" elif pat.operator == Eq: result = $pat.value + +proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} = + write(pb, field, value.data.buffer) + +proc getField*(pb: var ProtoBuffer, field: int, + value: var MultiAddress): bool {.inline.} = + var buffer: seq[byte] + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + let ma = MultiAddress.init(buffer) + if ma.isOk(): + value = ma.get() + true + else: + false + +proc getRepeatedField*(pb: var ProtoBuffer, field: int, + value: var seq[MultiAddress]): bool {.inline.} = + var items: seq[seq[byte]] + value.setLen(0) + if not(getRepeatedField(pb, field, items)): + return false + if len(items) == 0: + return true + for item in items: + let ma = MultiAddress.init(item) + if ma.isOk(): + value.add(ma.get()) + else: + value.setLen(0) + return false diff --git a/libp2p/peerid.nim b/libp2p/peerid.nim index e20417d0c..5c81c664b 100644 --- a/libp2p/peerid.nim +++ b/libp2p/peerid.nim @@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} = ## Write PeerID value ``peerid`` to buffer ``vb``. vb.writeSeq(pid.data) -proc initProtoField*(index: int, pid: PeerID): ProtoField = +proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} = ## Initialize ProtoField with PeerID ``value``. result = initProtoField(index, pid.data) -proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int = +proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {. + deprecated.} = ## Read ``PeerID`` from ProtoBuf's message and validate it. var pid: PeerID result = getLengthValue(data, field, pid.data) @@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int = result = -1 else: value = pid + +proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) = + ## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding. + write(pb, field, pid.data) + +proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool = + ## Read ``PeerID`` from ProtoBuf's message and validate it + var buffer: seq[byte] + var peerId: PeerID + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if peerId.init(buffer): + pid = peerId + true + else: + false diff --git a/libp2p/protobuf/minprotobuf.nim b/libp2p/protobuf/minprotobuf.nim index caef22c1f..5a00c4726 100644 --- a/libp2p/protobuf/minprotobuf.nim +++ b/libp2p/protobuf/minprotobuf.nim @@ -11,7 +11,7 @@ {.push raises: [Defect].} -import ../varint +import ../varint, stew/endians2 const MaxMessageSize* = 1'u shl 22 @@ -32,10 +32,14 @@ type offset*: int length*: int + ProtoHeader* = object + wire*: ProtoFieldKind + index*: uint64 + ProtoField* = object ## Protobuf's message field representation object - index: int - case kind: ProtoFieldKind + index*: int + case kind*: ProtoFieldKind of Varint: vint*: uint64 of Fixed64: @@ -47,13 +51,35 @@ type of StartGroup, EndGroup: discard -template protoHeader*(index: int, wire: ProtoFieldKind): uint = - ## Get protobuf's field header integer for ``index`` and ``wire``. - ((uint(index) shl 3) or cast[uint](wire)) + ProtoResult {.pure.} = enum + VarintDecodeError, + MessageIncompleteError, + BufferOverflowError, + MessageSizeTooBigError, + NoError -template protoHeader*(field: ProtoField): uint = + ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 | + hint | hint32 | hint64 | float32 | float64 + +const + SupportedWireTypes* = { + int(ProtoFieldKind.Varint), + int(ProtoFieldKind.Fixed64), + int(ProtoFieldKind.Length), + int(ProtoFieldKind.Fixed32) + } + +template checkFieldNumber*(i: int) = + doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999), + "Incorrect or reserved field number") + +template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 = + ## Get protobuf's field header integer for ``index`` and ``wire``. + ((uint64(index) shl 3) or uint64(wire)) + +template getProtoHeader*(field: ProtoField): uint64 = ## Get protobuf's field header integer for ``field``. - ((uint(field.index) shl 3) or cast[uint](field.kind)) + ((uint64(field.index) shl 3) or uint64(field.kind)) template toOpenArray*(pb: ProtoBuffer): untyped = toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1) @@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int = proc vsizeof*(field: ProtoField): int {.inline.} = ## Returns number of bytes required to store protobuf's field ``field``. - result = vsizeof(protoHeader(field)) case field.kind of ProtoFieldKind.Varint: - result += vsizeof(field.vint) + vsizeof(getProtoHeader(field)) + vsizeof(field.vint) of ProtoFieldKind.Fixed64: - result += sizeof(field.vfloat64) + vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64) of ProtoFieldKind.Fixed32: - result += sizeof(field.vfloat32) + vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32) of ProtoFieldKind.Length: - result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer) + vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) + + len(field.vbuffer) else: - discard + 0 -proc initProtoField*(index: int, value: SomeVarint): ProtoField = +proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} = ## Initialize ProtoField with integer value. result = ProtoField(kind: Varint, index: index) when type(value) is uint64: @@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField = else: result.vint = cast[uint64](value) -proc initProtoField*(index: int, value: bool): ProtoField = +proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} = ## Initialize ProtoField with integer value. result = ProtoField(kind: Varint, index: index) result.vint = byte(value) -proc initProtoField*(index: int, value: openarray[byte]): ProtoField = +proc initProtoField*(index: int, + value: openarray[byte]): ProtoField {.deprecated.} = ## Initialize ProtoField with bytes array. result = ProtoField(kind: Length, index: index) if len(value) > 0: result.vbuffer = newSeq[byte](len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) -proc initProtoField*(index: int, value: string): ProtoField = +proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} = ## Initialize ProtoField with string. result = ProtoField(kind: Length, index: index) if len(value) > 0: result.vbuffer = newSeq[byte](len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) -proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} = +proc initProtoField*(index: int, + value: ProtoBuffer): ProtoField {.deprecated, inline.} = ## Initialize ProtoField with nested message stored in ``value``. ## ## Note: This procedure performs shallow copy of ``value`` sequence. @@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0, result.offset = offset result.options = options +proc initProtoBuffer*(data: openarray[byte], offset = 0, + options: set[ProtoFlags] = {}): ProtoBuffer = + ## Initialize ProtoBuffer with copy of ``data``. + result.buffer = @data + result.offset = offset + result.options = options + proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = ## Initialize ProtoBuffer with new sequence of capacity ``cap``. result.buffer = newSeqOfCap[byte](128) @@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = result.offset = 10 elif {WithUint32LeLength, WithUint32BeLength} * options != {}: # Our buffer will start from position 4, so we can store length of buffer - # in [0, 9]. + # in [0, 3]. result.buffer.setLen(4) result.offset = 4 -proc write*(pb: var ProtoBuffer, field: ProtoField) = +proc write*[T: ProtoScalar](pb: var ProtoBuffer, + field: int, value: T) = + checkFieldNumber(field) + var length = 0 + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) + + vsizeof(value) + let header = ProtoFieldKind.Varint + elif T is float32: + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) + + sizeof(T) + let header = ProtoFieldKind.Fixed32 + elif T is float64: + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) + + sizeof(T) + let header = ProtoFieldKind.Fixed64 + + pb.buffer.setLen(len(pb.buffer) + flength) + + let hres = PB.putUVarint(pb.toOpenArray(), length, + getProtoHeader(field, header)) + doAssert(hres.isOk()) + pb.offset += length + when (T is uint64) or (T is uint32) or (T is uint): + let vres = PB.putUVarint(pb.toOpenArray(), length, value) + doAssert(vres.isOk()) + pb.offset += length + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + let vres = putSVarint(pb.toOpenArray(), length, value) + doAssert(vres.isOk()) + pb.offset += length + elif T is float32: + doAssert(pb.isEnough(sizeof(T))) + let u32 = cast[uint32](value) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE() + pb.offset += sizeof(T) + elif T is float64: + doAssert(pb.isEnough(sizeof(T))) + let u64 = cast[uint64](value) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE() + pb.offset += sizeof(T) + +proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int, + value: openarray[T]) = + checkFieldNumber(field) + var length = 0 + let dlength = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + var res = 0 + for item in value: + res += vsizeof(item) + res + elif (T is float32) or (T is float64): + len(value) * sizeof(T) + + let header = getProtoHeader(field, ProtoFieldKind.Length) + let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength + pb.buffer.setLen(len(pb.buffer) + flength) + let hres = PB.putUVarint(pb.toOpenArray(), length, header) + doAssert(hres.isOk()) + pb.offset += length + length = 0 + let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength)) + doAssert(lres.isOk()) + pb.offset += length + for item in value: + when (T is uint64) or (T is uint32) or (T is uint): + length = 0 + let vres = PB.putUVarint(pb.toOpenArray(), length, item) + doAssert(vres.isOk()) + pb.offset += length + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + length = 0 + let vres = PB.putSVarint(pb.toOpenArray(), length, item) + doAssert(vres.isOk()) + pb.offset += length + elif T is float32: + doAssert(pb.isEnough(sizeof(T))) + let u32 = cast[uint32](item) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE() + pb.offset += sizeof(T) + elif T is float64: + doAssert(pb.isEnough(sizeof(T))) + let u64 = cast[uint64](item) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE() + pb.offset += sizeof(T) + +proc write*[T: byte|char](pb: var ProtoBuffer, field: int, + value: openarray[T]) = + checkFieldNumber(field) + var length = 0 + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) + + vsizeof(uint64(len(value))) + len(value) + pb.buffer.setLen(len(pb.buffer) + flength) + let hres = PB.putUVarint(pb.toOpenArray(), length, + getProtoHeader(field, ProtoFieldKind.Length)) + doAssert(hres.isOk()) + pb.offset += length + let lres = PB.putUVarint(pb.toOpenArray(), length, + uint64(len(value))) + doAssert(lres.isOk()) + pb.offset += length + if len(value) > 0: + doAssert(pb.isEnough(len(value))) + copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value)) + pb.offset += len(value) + +proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} = + ## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer + ## ``pb`` with field number ``field``. + write(pb, field, value.buffer) + +proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} = ## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``. var length = 0 var res: VarintResult[void] pb.buffer.setLen(len(pb.buffer) + vsizeof(field)) - res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field)) + res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field)) doAssert(res.isOk()) pb.offset += length case field.kind @@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) = pb.offset = pos elif WithUint32BeLength in pb.options: let size = uint(len(pb.buffer) - 4) - pb.buffer[0] = byte((size shr 24) and 0xFF'u) - pb.buffer[1] = byte((size shr 16) and 0xFF'u) - pb.buffer[2] = byte((size shr 8) and 0xFF'u) - pb.buffer[3] = byte(size and 0xFF'u) + pb.buffer[0 ..< 4] = toBytesBE(uint32(size)) pb.offset = 4 elif WithUint32LeLength in pb.options: let size = uint(len(pb.buffer) - 4) - pb.buffer[0] = byte(size and 0xFF'u) - pb.buffer[1] = byte((size shr 8) and 0xFF'u) - pb.buffer[2] = byte((size shr 16) and 0xFF'u) - pb.buffer[3] = byte((size shr 24) and 0xFF'u) + pb.buffer[0 ..< 4] = toBytesLE(uint32(size)) pb.offset = 4 else: pb.offset = 0 +proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool = + var length = 0 + var hdr = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, hdr).isOk(): + let index = uint64(hdr shr 3) + let wire = hdr and 0x07 + if wire in SupportedWireTypes: + data.offset += length + header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire)) + true + else: + false + else: + false + +proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool = + case header.wire + of ProtoFieldKind.Varint: + var length = 0 + var value = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + true + else: + false + of ProtoFieldKind.Fixed32: + if data.isEnough(sizeof(uint32)): + data.offset += sizeof(uint32) + true + else: + false + of ProtoFieldKind.Fixed64: + if data.isEnough(sizeof(uint64)): + data.offset += sizeof(uint64) + true + else: + false + of ProtoFieldKind.Length: + var length = 0 + var bsize = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + data.offset += int(bsize) + true + else: + false + else: + false + else: + false + of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup: + false + +proc getValue[T: ProtoScalar](data: var ProtoBuffer, + header: ProtoHeader, + outval: var T): ProtoResult = + when (T is uint64) or (T is uint32) or (T is uint): + doAssert(header.wire == ProtoFieldKind.Varint) + var length = 0 + var value = T(0) + if PB.getUVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + outval = value + ProtoResult.NoError + else: + ProtoResult.VarintDecodeError + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + doAssert(header.wire == ProtoFieldKind.Varint) + var length = 0 + var value = T(0) + if getSVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + outval = value + ProtoResult.NoError + else: + ProtoResult.VarintDecodeError + elif T is float32: + doAssert(header.wire == ProtoFieldKind.Fixed32) + if data.isEnough(sizeof(float32)): + outval = cast[float32](fromBytesLE(uint32, data.toOpenArray())) + data.offset += sizeof(float32) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + elif T is float64: + doAssert(header.wire == ProtoFieldKind.Fixed64) + if data.isEnough(sizeof(float64)): + outval = cast[float64](fromBytesLE(uint64, data.toOpenArray())) + data.offset += sizeof(float64) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + +proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, + outBytes: var openarray[T], + outLength: var int): ProtoResult = + doAssert(header.wire == ProtoFieldKind.Length) + var length = 0 + var bsize = 0'u64 + + outLength = 0 + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + outLength = int(bsize) + if len(outBytes) >= int(bsize): + if bsize > 0'u64: + copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) + data.offset += int(bsize) + ProtoResult.NoError + else: + # Buffer overflow should not be critical failure + data.offset += int(bsize) + ProtoResult.BufferOverflowError + else: + ProtoResult.MessageIncompleteError + else: + ProtoResult.MessageSizeTooBigError + else: + ProtoResult.VarintDecodeError + +proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, + outBytes: var T): ProtoResult = + doAssert(header.wire == ProtoFieldKind.Length) + var length = 0 + var bsize = 0'u64 + outBytes.setLen(0) + + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + outBytes.setLen(bsize) + if bsize > 0'u64: + copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) + data.offset += int(bsize) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + else: + ProtoResult.MessageSizeTooBigError + else: + ProtoResult.VarintDecodeError + +proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int, + output: var T): bool = + checkFieldNumber(field) + var value: T + var res = false + var pb = data + output = T(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output = T(0) + return false + let wireCheck = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + header.wire == ProtoFieldKind.Varint + elif T is float32: + header.wire == ProtoFieldKind.Fixed32 + elif T is float64: + header.wire == ProtoFieldKind.Fixed64 + if header.index == uint64(field): + if wireCheck: + let r = getValue(pb, header, value) + case r + of ProtoResult.NoError: + res = true + output = value + else: + return false + else: + # We are ignoring wire types different from what we expect, because it + # is how `protoc` is working. + if not(skipValue(pb, header)): + output = T(0) + return false + else: + if not(skipValue(pb, header)): + output = T(0) + return false + res + +proc getField*[T: byte|char](data: ProtoBuffer, field: int, + output: var openarray[T], + outlen: var int): bool = + checkFieldNumber(field) + var pb = data + var res = false + + outlen = 0 + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + let r = getValue(pb, header, output, outlen) + case r + of ProtoResult.NoError: + res = true + of ProtoResult.BufferOverflowError: + # Buffer overflow error is not critical error, we still can get + # field values with proper size. + discard + else: + if len(output) > 0: + zeroMem(addr output[0], len(output)) + return false + else: + # We are ignoring wire types different from ProtoFieldKind.Length, + # because it is how `protoc` is working. + if not(skipValue(pb, header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + else: + if not(skipValue(pb, header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + + res + +proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int, + output: var T): bool = + checkFieldNumber(field) + var res = false + var pb = data + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + let r = getValue(pb, header, output) + case r + of ProtoResult.NoError: + res = true + of ProtoResult.BufferOverflowError: + # Buffer overflow error is not critical error, we still can get + # field values with proper size. + discard + else: + output.setLen(0) + return false + else: + # We are ignoring wire types different from ProtoFieldKind.Length, + # because it is how `protoc` is working. + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + res + +proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {. + inline.} = + var buffer: seq[byte] + if pb.getField(field, buffer): + output = initProtoBuffer(buffer) + true + else: + false + +proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + var item: T + let r = getValue(pb, header, item) + case r + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + +proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer, + field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32, + ProtoFieldKind.Fixed64}: + var item: T + let r = getValue(pb, header, item) + case r + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + +proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + var arritem: seq[byte] + let rarr = getValue(pb, header, arritem) + case rarr + of ProtoResult.NoError: + var pbarr = initProtoBuffer(arritem) + let itemHeader = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + ProtoHeader(wire: ProtoFieldKind.Varint) + elif T is float32: + ProtoHeader(wire: ProtoFieldKind.Fixed32) + elif T is float64: + ProtoHeader(wire: ProtoFieldKind.Fixed64) + while not(pbarr.isEmpty()): + var item: T + let res = getValue(pbarr, itemHeader, item) + case res + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + proc getVarintValue*(data: var ProtoBuffer, field: int, - value: var SomeVarint): int = + value: var SomeVarint): int {.deprecated.} = ## Get value of `Varint` type. var length = 0 var header = 0'u64 var soffset = data.offset - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, header).isOk(): data.offset += length - if header == protoHeader(field, Varint): + if header == getProtoHeader(field, Varint): if not data.isEmpty(): when type(value) is int32 or type(value) is int64 or type(value) is int: let res = getSVarint(data.toOpenArray(), length, value) @@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int, data.offset = soffset proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, - buffer: var T): int = + buffer: var T): int {.deprecated.} = ## Get value of `Length` type. var length = 0 var header = 0'u64 @@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, var soffset = data.offset result = -1 buffer.setLen(0) - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, header).isOk(): data.offset += length - if header == protoHeader(field, Length): - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk(): + if header == getProtoHeader(field, Length): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, ssize).isOk(): data.offset += length if ssize <= MaxMessageSize and data.isEnough(int(ssize)): buffer.setLen(ssize) @@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, data.offset = soffset proc getBytes*(data: var ProtoBuffer, field: int, - buffer: var seq[byte]): int {.inline.} = + buffer: var seq[byte]): int {.deprecated, inline.} = ## Get value of `Length` type as bytes. result = getLengthValue(data, field, buffer) proc getString*(data: var ProtoBuffer, field: int, - buffer: var string): int {.inline.} = + buffer: var string): int {.deprecated, inline.} = ## Get value of `Length` type as string. result = getLengthValue(data, field, buffer) -proc enterSubmessage*(pb: var ProtoBuffer): int = +proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} = ## Processes protobuf's sub-message and adjust internal offset to enter ## inside of sub-message. Returns field index of sub-message field or ## ``0`` on error. @@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int = var msize = 0'u64 var soffset = pb.offset - if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk(): + if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), + length, header).isOk(): pb.offset += length if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length): - if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk(): + if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), + length, msize).isOk(): pb.offset += length if msize <= MaxMessageSize and pb.isEnough(int(msize)): pb.length = int(msize) @@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int = # Restore offset on error pb.offset = soffset -proc skipSubmessage*(pb: var ProtoBuffer) = +proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} = ## Skip current protobuf's sub-message and adjust internal offset to the ## end of sub-message. doAssert(pb.length != 0) diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 735d740af..a998ea983 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -47,61 +47,49 @@ type proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = result = initProtoBuffer() - result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet())) + result.write(1, peerInfo.publicKey.get().getBytes().tryGet()) for ma in peerInfo.addrs: - result.write(initProtoField(2, ma.data.buffer)) + result.write(2, ma.data.buffer) for proto in peerInfo.protocols: - result.write(initProtoField(3, proto)) + result.write(3, proto) - result.write(initProtoField(4, observedAddr.data.buffer)) + result.write(4, observedAddr.data.buffer) let protoVersion = ProtoVersion - result.write(initProtoField(5, protoVersion)) + result.write(5, protoVersion) let agentVersion = AgentVersion - result.write(initProtoField(6, agentVersion)) + result.write(6, agentVersion) result.finish() proc decodeMsg*(buf: seq[byte]): IdentifyInfo = var pb = initProtoBuffer(buf) - result.pubKey = none(PublicKey) var pubKey: PublicKey - if pb.getValue(1, pubKey) > 0: + if pb.getField(1, pubKey): trace "read public key from message", pubKey = ($pubKey).shortLog result.pubKey = some(pubKey) - result.addrs = newSeq[MultiAddress]() - var address = newSeq[byte]() - while pb.getBytes(2, address) > 0: - if len(address) != 0: - var copyaddr = address - var ma = MultiAddress.init(copyaddr).tryGet() - result.addrs.add(ma) - trace "read address bytes from message", address = ma - address.setLen(0) + if pb.getRepeatedField(2, result.addrs): + trace "read addresses from message", addresses = result.addrs - var proto = "" - while pb.getString(3, proto) > 0: - trace "read proto from message", proto = proto - result.protos.add(proto) - proto = "" + if pb.getRepeatedField(3, result.protos): + trace "read protos from message", protocols = result.protos - var observableAddr = newSeq[byte]() - if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr - var ma = MultiAddress.init(observableAddr).tryGet() - trace "read observedAddr from message", address = ma - result.observedAddr = some(ma) + var observableAddr: MultiAddress + if pb.getField(4, observableAddr): + trace "read observableAddr from message", address = observableAddr + result.observedAddr = some(observableAddr) var protoVersion = "" - if pb.getString(5, protoVersion) > 0: + if pb.getField(5, protoVersion): trace "read protoVersion from message", protoVersion = protoVersion result.protoVersion = some(protoVersion) var agentVersion = "" - if pb.getString(6, agentVersion) > 0: + if pb.getField(6, agentVersion): trace "read agentVersion from message", agentVersion = agentVersion result.agentVersion = some(agentVersion) diff --git a/libp2p/protocols/pubsub/rpc/message.nim b/libp2p/protocols/pubsub/rpc/message.nim index d203035d4..9ff941853 100644 --- a/libp2p/protocols/pubsub/rpc/message.nim +++ b/libp2p/protocols/pubsub/rpc/message.nim @@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string = byteutils.toHex(m.seqno) & m.fromPeer.pretty proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} = - var buff = initProtoBuffer() - encodeMessage(msg, buff) - p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes() + p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes() proc verify*(m: Message, p: PeerInfo): bool = if m.signature.len > 0 and m.key.len > 0: @@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool = msg.signature = @[] msg.key = @[] - var buff = initProtoBuffer() - encodeMessage(msg, buff) - var remote: Signature var key: PublicKey if remote.init(m.signature) and key.init(m.key): trace "verifying signature", remoteSignature = remote - result = remote.verify(PubSubPrefix & buff.buffer, key) + result = remote.verify(PubSubPrefix & encodeMessage(msg), key) if result: libp2p_pubsub_sig_verify_success.inc() diff --git a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim index 922546b20..c5a3eb309 100644 --- a/libp2p/protocols/pubsub/rpc/protobuf.nim +++ b/libp2p/protocols/pubsub/rpc/protobuf.nim @@ -14,265 +14,247 @@ import messages, ../../../utility, ../../../protobuf/minprotobuf -proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, graft.topicID)) +proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) = + var ipb = initProtoBuffer() + ipb.write(1, graft.topicID) + ipb.finish() + pb.write(field, ipb) -proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} = - trace "decoding graft msg", buffer = pb.buffer.shortLog - while true: - var topic: string - if pb.getString(1, topic) < 0: - break +proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) = + var ipb = initProtoBuffer() + ipb.write(1, prune.topicID) + ipb.finish() + pb.write(field, ipb) - trace "read topic field from graft msg", topicID = topic - result.add(ControlGraft(topicID: topic)) - -proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, prune.topicID)) - -proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} = - trace "decoding prune msg" - while true: - var topic: string - if pb.getString(1, topic) < 0: - break - - trace "read topic field from prune msg", topicID = topic - result.add(ControlPrune(topicID: topic)) - -proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, ihave.topicID)) +proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) = + var ipb = initProtoBuffer() + ipb.write(1, ihave.topicID) for mid in ihave.messageIDs: - pb.write(initProtoField(2, mid)) + ipb.write(2, mid) + ipb.finish() + pb.write(field, ipb) -proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} = - trace "decoding ihave msg" - - while true: - var control: ControlIHave - if pb.getString(1, control.topicID) < 0: - trace "topic field missing from ihave msg" - break - - trace "read topic field", topicID = control.topicID - - while true: - var mid: string - if pb.getString(2, mid) < 0: - break - trace "read messageID field", mid = mid - control.messageIDs.add(mid) - - result.add(control) - -proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} = +proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) = + var ipb = initProtoBuffer() for mid in iwant.messageIDs: - pb.write(initProtoField(1, mid)) + ipb.write(1, mid) + if len(ipb.buffer) > 0: + ipb.finish() + pb.write(field, ipb) -proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} = - trace "decoding iwant msg" +proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) = + var ipb = initProtoBuffer() + for ihave in control.ihave: + ipb.write(1, ihave) + for iwant in control.iwant: + ipb.write(2, iwant) + for graft in control.graft: + ipb.write(3, graft) + for prune in control.prune: + ipb.write(4, prune) + if len(ipb.buffer) > 0: + ipb.finish() + pb.write(field, ipb) - var control: ControlIWant - while true: - var mid: string - if pb.getString(1, mid) < 0: - break - control.messageIDs.add(mid) - trace "read messageID field", mid = mid - result.add(control) - -proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = - if control.ihave.len > 0: - var ihave = initProtoBuffer() - for h in control.ihave: - h.encodeIHave(ihave) - - # write messages to protobuf - if ihave.buffer.len > 0: - ihave.finish() - pb.write(initProtoField(1, ihave)) - - if control.iwant.len > 0: - var iwant = initProtoBuffer() - for w in control.iwant: - w.encodeIWant(iwant) - - # write messages to protobuf - if iwant.buffer.len > 0: - iwant.finish() - pb.write(initProtoField(2, iwant)) - - if control.graft.len > 0: - var graft = initProtoBuffer() - for g in control.graft: - g.encodeGraft(graft) - - # write messages to protobuf - if graft.buffer.len > 0: - graft.finish() - pb.write(initProtoField(3, graft)) - - if control.prune.len > 0: - var prune = initProtoBuffer() - for p in control.prune: - p.encodePrune(prune) - - # write messages to protobuf - if prune.buffer.len > 0: - prune.finish() - pb.write(initProtoField(4, prune)) - -proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} = - trace "decoding control submessage" - var control: ControlMessage - while true: - var field = pb.enterSubMessage() - trace "processing submessage", field = field - case field: - of 0: - trace "no submessage found in Control msg" - break - of 1: - control.ihave &= pb.decodeIHave() - of 2: - control.iwant &= pb.decodeIWant() - of 3: - control.graft &= pb.decodeGraft() - of 4: - control.prune &= pb.decodePrune() - else: - raise newException(CatchableError, "message type not recognized") - - if result.isNone: - result = some(control) - -proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, subs.subscribe)) - pb.write(initProtoField(2, subs.topic)) - -proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} = - while true: - var subOpt: SubOpts - var subscr: uint - discard pb.getVarintValue(1, subscr) - subOpt.subscribe = cast[bool](subscr) - trace "read subscribe field", subscribe = subOpt.subscribe - - if pb.getString(2, subOpt.topic) < 0: - break - trace "read subscribe field", topicName = subOpt.topic - - result.add(subOpt) - - trace "got subscriptions", subscriptions = result - -proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, msg.fromPeer.getBytes())) - pb.write(initProtoField(2, msg.data)) - pb.write(initProtoField(3, msg.seqno)) - - for t in msg.topicIDs: - pb.write(initProtoField(4, t)) - - if msg.signature.len > 0: - pb.write(initProtoField(5, msg.signature)) - - if msg.key.len > 0: - pb.write(initProtoField(6, msg.key)) +proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) = + var ipb = initProtoBuffer() + ipb.write(1, uint64(subs.subscribe)) + ipb.write(2, subs.topic) + ipb.finish() + pb.write(field, ipb) +proc encodeMessage*(msg: Message): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, msg.fromPeer) + pb.write(2, msg.data) + pb.write(3, msg.seqno) + for topic in msg.topicIDs: + pb.write(4, topic) + if len(msg.signature) > 0: + pb.write(5, msg.signature) + if len(msg.key) > 0: + pb.write(6, msg.key) pb.finish() + pb.buffer -proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} = - # TODO: which of this fields are really optional? - while true: - var msg: Message - var fromPeer: seq[byte] - if pb.getBytes(1, fromPeer) < 0: - break - try: - msg.fromPeer = PeerID.init(fromPeer).tryGet() - except CatchableError as err: - debug "Invalid fromPeer in message", msg = err.msg - break +proc write*(pb: var ProtoBuffer, field: int, msg: Message) = + pb.write(field, encodeMessage(msg)) - trace "read message field", fromPeer = msg.fromPeer.pretty +proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} = + trace "decodeGraft: decoding message" + var control = ControlGraft() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodeGraft: read topicId", topic_id = topicId + else: + trace "decodeGraft: topicId is missing" + control - if pb.getBytes(2, msg.data) < 0: - break - trace "read message field", data = msg.data.shortLog +proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} = + trace "decodePrune: decoding message" + var control = ControlPrune() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodePrune: read topicId", topic_id = topicId + else: + trace "decodePrune: topicId is missing" + control - if pb.getBytes(3, msg.seqno) < 0: - break - trace "read message field", seqno = msg.seqno.shortLog +proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} = + trace "decodeIHave: decoding message" + var control = ControlIHave() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodeIHave: read topicId", topic_id = topicId + else: + trace "decodeIHave: topicId is missing" + if pb.getRepeatedField(2, control.messageIDs): + trace "decodeIHave: read messageIDs", message_ids = control.messageIDs + else: + trace "decodeIHave: no messageIDs" + control - var topic: string - while true: - if pb.getString(4, topic) < 0: - break - msg.topicIDs.add(topic) - trace "read message field", topicName = topic - topic = "" +proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} = + trace "decodeIWant: decoding message" + var control = ControlIWant() + if pb.getRepeatedField(1, control.messageIDs): + trace "decodeIWant: read messageIDs", message_ids = control.messageIDs + else: + trace "decodeIWant: no messageIDs" - discard pb.getBytes(5, msg.signature) - trace "read message field", signature = msg.signature.shortLog +proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} = + trace "decodeControl: decoding message" + var buffer: seq[byte] + if pb.getField(3, buffer): + var control: ControlMessage + var cpb = initProtoBuffer(buffer) + var ihavepbs: seq[seq[byte]] + var iwantpbs: seq[seq[byte]] + var graftpbs: seq[seq[byte]] + var prunepbs: seq[seq[byte]] - discard pb.getBytes(6, msg.key) - trace "read message field", key = msg.key.shortLog + discard cpb.getRepeatedField(1, ihavepbs) + discard cpb.getRepeatedField(2, iwantpbs) + discard cpb.getRepeatedField(3, graftpbs) + discard cpb.getRepeatedField(4, prunepbs) - result.add(msg) + for item in ihavepbs: + control.ihave.add(decodeIHave(initProtoBuffer(item))) + for item in iwantpbs: + control.iwant.add(decodeIWant(initProtoBuffer(item))) + for item in graftpbs: + control.graft.add(decodeGraft(initProtoBuffer(item))) + for item in prunepbs: + control.prune.add(decodePrune(initProtoBuffer(item))) -proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = - result = initProtoBuffer() - trace "encoding msg: ", msg = msg.shortLog + trace "decodeControl: " + some(control) + else: + none[ControlMessage]() - if msg.subscriptions.len > 0: - for s in msg.subscriptions: - var subs = initProtoBuffer() - encodeSubs(s, subs) +proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} = + trace "decodeSubscription: decoding message" + var subflag: uint64 + var sub = SubOpts() + if pb.getField(1, subflag): + sub.subscribe = bool(subflag) + trace "decodeSubscription: read subscribe", subscribe = subflag + else: + trace "decodeSubscription: subscribe is missing" + if pb.getField(2, sub.topic): + trace "decodeSubscription: read topic", topic = sub.topic + else: + trace "decodeSubscription: topic is missing" - # write subscriptions to protobuf - if subs.buffer.len > 0: - subs.finish() - result.write(initProtoField(1, subs)) + sub - if msg.messages.len > 0: - var messages = initProtoBuffer() - for m in msg.messages: - encodeMessage(m, messages) +proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} = + trace "decodeSubscriptions: decoding message" + var subpbs: seq[seq[byte]] + var subs: seq[SubOpts] + if pb.getRepeatedField(1, subpbs): + trace "decodeSubscriptions: read subscriptions", count = len(subpbs) + for item in subpbs: + let sub = decodeSubscription(initProtoBuffer(item)) + subs.add(sub) - # write messages to protobuf - if messages.buffer.len > 0: - messages.finish() - result.write(initProtoField(2, messages)) + if len(subs) == 0: + trace "decodeSubscription: no subscriptions found" - if msg.control.isSome: - var control = initProtoBuffer() - msg.control.get.encodeControl(control) + subs - # write messages to protobuf - if control.buffer.len > 0: - control.finish() - result.write(initProtoField(3, control)) +proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} = + trace "decodeMessage: decoding message" + var msg: Message + if pb.getField(1, msg.fromPeer): + trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty() + else: + trace "decodeMessage: fromPeer is missing" - if result.buffer.len > 0: - result.finish() + if pb.getField(2, msg.data): + trace "decodeMessage: read data", data = msg.data.shortLog() + else: + trace "decodeMessage: data is missing" -proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} = + if pb.getField(3, msg.seqno): + trace "decodeMessage: read seqno", seqno = msg.data.shortLog() + else: + trace "decodeMessage: seqno is missing" + + if pb.getRepeatedField(4, msg.topicIDs): + trace "decodeMessage: read topics", topic_ids = msg.topicIDs + else: + trace "decodeMessage: topics are missing" + + if pb.getField(5, msg.signature): + trace "decodeMessage: read signature", signature = msg.signature.shortLog() + else: + trace "decodeMessage: signature is missing" + + if pb.getField(6, msg.key): + trace "decodeMessage: read public key", key = msg.key.shortLog() + else: + trace "decodeMessage: public key is missing" + + msg + +proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} = + trace "decodeMessages: decoding message" + var msgpbs: seq[seq[byte]] + var msgs: seq[Message] + if pb.getRepeatedField(2, msgpbs): + trace "decodeMessages: read messages", count = len(msgpbs) + for item in msgpbs: + let msg = decodeMessage(initProtoBuffer(item)) + msgs.add(msg) + + if len(msgs) == 0: + trace "decodeMessages: no messages found" + + msgs + +proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer = + trace "encodeRpcMsg: encoding message", msg = msg.shortLog() + var pb = initProtoBuffer() + for item in msg.subscriptions: + pb.write(1, item) + for item in msg.messages: + pb.write(2, item) + if msg.control.isSome(): + pb.write(3, msg.control.get()) + if len(pb.buffer) > 0: + pb.finish() + result = pb + +proc decodeRpcMsg*(msg: seq[byte]): RPCMsg = + trace "decodeRpcMsg: decoding message", msg = msg.shortLog() var pb = initProtoBuffer(msg) + var rpcMsg: RPCMsg + rpcMsg.messages = pb.decodeMessages() + rpcMsg.subscriptions = pb.decodeSubscriptions() + rpcMsg.control = pb.decodeControl() - while true: - # decode SubOpts array - var field = pb.enterSubMessage() - trace "processing submessage", field = field - case field: - of 0: - trace "no submessage found in RPC msg" - break - of 1: - result.subscriptions &= pb.decodeSubs() - of 2: - result.messages &= pb.decodeMessages() - of 3: - result.control = pb.decodeControl() - else: - raise newException(CatchableError, "message type not recognized") + rpcMsg diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index fa14674bb..d5398ccbb 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -430,8 +430,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon var libp2pProof = initProtoBuffer() - libp2pProof.write(initProtoField(1, p.localPublicKey)) - libp2pProof.write(initProtoField(2, signedPayload.getBytes())) + libp2pProof.write(1, p.localPublicKey) + libp2pProof.write(2, signedPayload.getBytes()) # data field also there but not used! libp2pProof.finish() @@ -449,9 +449,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon remoteSig: Signature remoteSigBytes: seq[byte] - if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: + if not(remoteProof.getField(1, remotePubKeyBytes)): raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: + if not(remoteProof.getField(2, remoteSigBytes)): raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") if not remotePubKey.init(remotePubKeyBytes): diff --git a/tests/testminprotobuf.nim b/tests/testminprotobuf.nim new file mode 100644 index 000000000..a4fe7fead --- /dev/null +++ b/tests/testminprotobuf.nim @@ -0,0 +1,677 @@ +## Nim-Libp2p +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import unittest +import ../libp2p/protobuf/minprotobuf +import stew/byteutils, strutils + +when defined(nimHasUsed): {.used.} + +suite "MinProtobuf test suite": + const VarintVectors = [ + "0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f", + "08ffffffffffffffffff01" + ] + + const VarintValues = [ + 0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64, + 0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64 + ] + + const Fixed32Vectors = [ + "0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff" + ] + + const Fixed32Values = [ + 0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32 + ] + + const Fixed64Vectors = [ + "090000000000000000", "090100000000000000", "09ffffff7f00000000", + "09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f", + "099988ffeeddccbbaa", "09ffffffffffffffff" + ] + + const Fixed64Values = [ + 0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64, + 0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64, + 0xFFFF_FFFF_FFFF_FFFF'u64 + ] + + const LengthVectors = [ + "0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768" + ] + + const LengthValues = [ + "", "a", "ab", "abcd", "abcdefgh" + ] + + ## This vector values was tested with `protoc` and related proto file. + + ## syntax = "proto2"; + ## message testmsg { + ## repeated uint64 d = 1 [packed=true]; + ## repeated uint64 d = 2 [packed=true]; + ## } + const PackedVarintVector = + "0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001" + ## syntax = "proto2"; + ## message testmsg { + ## repeated sfixed32 d = 1 [packed=true]; + ## repeated sfixed32 d = 2 [packed=true]; + ## } + const PackedFixed32Vector = + "0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000" + ## syntax = "proto2"; + ## message testmsg { + ## repeated sfixed64 d = 1 [packed=true]; + ## repeated sfixed64 d = 2 [packed=true]; + ## } + const PackedFixed64Vector = + """0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000 + ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210 + 00000000000000000100000000000000""" + + proc getVarintEncodedValue(value: uint64): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getVarintDecodedValue(data: openarray[byte]): uint64 = + var value: uint64 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + value + + proc getFixed32EncodedValue(value: float32): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getFixed32DecodedValue(data: openarray[byte]): uint32 = + var value: float32 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + cast[uint32](value) + + proc getFixed64EncodedValue(value: float64): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getFixed64DecodedValue(data: openarray[byte]): uint64 = + var value: float64 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + cast[uint64](value) + + proc getLengthEncodedValue(value: string): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getLengthEncodedValue(value: seq[byte]): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getLengthDecodedValue(data: openarray[byte]): string = + var value = newString(len(data)) + var valueLen = 0 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value, valueLen) + + doAssert(res) + value.setLen(valueLen) + value + + proc isFullZero[T: byte|char](data: openarray[T]): bool = + for ch in data: + if int(ch) != 0: + return false + return true + + proc corruptHeader(data: var openarray[byte], index: int) = + var values = [3, 4, 6] + data[0] = data[0] and 0xF8'u8 + data[0] = data[0] or byte(values[index mod len(values)]) + + test "[varint] edge values test": + for i in 0 ..< len(VarintValues): + let data = getVarintEncodedValue(VarintValues[i]) + check: + toHex(data) == VarintVectors[i] + getVarintDecodedValue(data) == VarintValues[i] + + test "[varint] mixing many values with same field number test": + for i in 0 ..< len(VarintValues): + var pb = initProtoBuffer() + for k in 0 ..< len(VarintValues): + let index = (i + k + 1) mod len(VarintValues) + pb.write(1, VarintValues[index]) + pb.finish() + check getVarintDecodedValue(pb.buffer) == VarintValues[i] + + test "[varint] incorrect values test": + for i in 0 ..< len(VarintValues): + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[varint] non-existent field test": + for i in 0 ..< len(VarintValues): + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == 0'u64 + + test "[varint] corrupted header test": + for i in 0 ..< len(VarintValues): + for k in 0 ..< 3: + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[varint] empty buffer test": + var value: uint64 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == 0'u64 + + test "[varint] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, VarintValues[1]) + pb1.write(1, VarintValues[2]) + pb1.write(2, VarintValues[3]) + pb1.write(1, VarintValues[4]) + pb1.write(1, VarintValues[5]) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[uint64] + var fieldarr2: seq[uint64] + var fieldarr3: seq[uint64] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + fieldarr1[0] == VarintValues[1] + fieldarr1[1] == VarintValues[2] + fieldarr1[2] == VarintValues[4] + fieldarr1[3] == VarintValues[5] + fieldarr2[0] == VarintValues[3] + + test "[varint] Repeated packed field test": + var pb1 = initProtoBuffer() + pb1.writePacked(1, VarintValues) + pb1.writePacked(2, VarintValues[0 .. 1]) + pb1.finish() + check: + toHex(pb1.buffer) == PackedVarintVector + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[uint64] + var fieldarr2: seq[uint64] + var fieldarr3: seq[uint64] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 6 + fieldarr1[0] == VarintValues[0] + fieldarr1[1] == VarintValues[1] + fieldarr1[2] == VarintValues[2] + fieldarr1[3] == VarintValues[3] + fieldarr1[4] == VarintValues[4] + fieldarr1[5] == VarintValues[5] + fieldarr2[0] == VarintValues[0] + fieldarr2[1] == VarintValues[1] + + test "[fixed32] edge values test": + for i in 0 ..< len(Fixed32Values): + let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i])) + check: + toHex(data) == Fixed32Vectors[i] + getFixed32DecodedValue(data) == Fixed32Values[i] + + test "[fixed32] mixing many values with same field number test": + for i in 0 ..< len(Fixed32Values): + var pb = initProtoBuffer() + for k in 0 ..< len(Fixed32Values): + let index = (i + k + 1) mod len(Fixed32Values) + pb.write(1, cast[float32](Fixed32Values[index])) + pb.finish() + check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i] + + test "[fixed32] incorrect values test": + for i in 0 ..< len(Fixed32Values): + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed32] non-existent field test": + for i in 0 ..< len(Fixed32Values): + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == float32(0) + + test "[fixed32] corrupted header test": + for i in 0 ..< len(Fixed32Values): + for k in 0 ..< 3: + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed32] empty buffer test": + var value: float32 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == float32(0) + + test "[fixed32] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, cast[float32](Fixed32Values[0])) + pb1.write(1, cast[float32](Fixed32Values[1])) + pb1.write(2, cast[float32](Fixed32Values[2])) + pb1.write(1, cast[float32](Fixed32Values[3])) + pb1.write(1, cast[float32](Fixed32Values[4])) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float32] + var fieldarr2: seq[float32] + var fieldarr3: seq[float32] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[uint32](fieldarr1[0]) == Fixed64Values[0] + cast[uint32](fieldarr1[1]) == Fixed64Values[1] + cast[uint32](fieldarr1[2]) == Fixed64Values[3] + cast[uint32](fieldarr1[3]) == Fixed64Values[4] + cast[uint32](fieldarr2[0]) == Fixed64Values[2] + + test "[fixed32] Repeated packed field test": + var pb1 = initProtoBuffer() + var values = newSeq[float32](len(Fixed32Values)) + for i in 0 ..< len(values): + values[i] = cast[float32](Fixed32Values[i]) + pb1.writePacked(1, values) + pb1.writePacked(2, values[0 .. 1]) + pb1.finish() + check: + toHex(pb1.buffer) == PackedFixed32Vector + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float32] + var fieldarr2: seq[float32] + var fieldarr3: seq[float32] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 5 + cast[uint32](fieldarr1[0]) == Fixed32Values[0] + cast[uint32](fieldarr1[1]) == Fixed32Values[1] + cast[uint32](fieldarr1[2]) == Fixed32Values[2] + cast[uint32](fieldarr1[3]) == Fixed32Values[3] + cast[uint32](fieldarr1[4]) == Fixed32Values[4] + cast[uint32](fieldarr2[0]) == Fixed32Values[0] + cast[uint32](fieldarr2[1]) == Fixed32Values[1] + + test "[fixed64] edge values test": + for i in 0 ..< len(Fixed64Values): + let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + check: + toHex(data) == Fixed64Vectors[i] + getFixed64DecodedValue(data) == Fixed64Values[i] + + test "[fixed64] mixing many values with same field number test": + for i in 0 ..< len(Fixed64Values): + var pb = initProtoBuffer() + for k in 0 ..< len(Fixed64Values): + let index = (i + k + 1) mod len(Fixed64Values) + pb.write(1, cast[float64](Fixed64Values[index])) + pb.finish() + check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i] + + test "[fixed64] incorrect values test": + for i in 0 ..< len(Fixed64Values): + var value: float32 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed64] non-existent field test": + for i in 0 ..< len(Fixed64Values): + var value: float64 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == float64(0) + + test "[fixed64] corrupted header test": + for i in 0 ..< len(Fixed64Values): + for k in 0 ..< 3: + var value: float64 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed64] empty buffer test": + var value: float64 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == float64(0) + + test "[fixed64] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, cast[float64](Fixed64Values[2])) + pb1.write(1, cast[float64](Fixed64Values[3])) + pb1.write(2, cast[float64](Fixed64Values[4])) + pb1.write(1, cast[float64](Fixed64Values[5])) + pb1.write(1, cast[float64](Fixed64Values[6])) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float64] + var fieldarr2: seq[float64] + var fieldarr3: seq[float64] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[uint64](fieldarr1[0]) == Fixed64Values[2] + cast[uint64](fieldarr1[1]) == Fixed64Values[3] + cast[uint64](fieldarr1[2]) == Fixed64Values[5] + cast[uint64](fieldarr1[3]) == Fixed64Values[6] + cast[uint64](fieldarr2[0]) == Fixed64Values[4] + + test "[fixed64] Repeated packed field test": + var pb1 = initProtoBuffer() + var values = newSeq[float64](len(Fixed64Values)) + for i in 0 ..< len(values): + values[i] = cast[float64](Fixed64Values[i]) + pb1.writePacked(1, values) + pb1.writePacked(2, values[0 .. 1]) + pb1.finish() + let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", "")) + check: + toHex(pb1.buffer) == expect + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float64] + var fieldarr2: seq[float64] + var fieldarr3: seq[float64] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 8 + cast[uint64](fieldarr1[0]) == Fixed64Values[0] + cast[uint64](fieldarr1[1]) == Fixed64Values[1] + cast[uint64](fieldarr1[2]) == Fixed64Values[2] + cast[uint64](fieldarr1[3]) == Fixed64Values[3] + cast[uint64](fieldarr1[4]) == Fixed64Values[4] + cast[uint64](fieldarr1[5]) == Fixed64Values[5] + cast[uint64](fieldarr1[6]) == Fixed64Values[6] + cast[uint64](fieldarr1[7]) == Fixed64Values[7] + cast[uint64](fieldarr2[0]) == Fixed64Values[0] + cast[uint64](fieldarr2[1]) == Fixed64Values[1] + + test "[length] edge values test": + for i in 0 ..< len(LengthValues): + let data1 = getLengthEncodedValue(LengthValues[i]) + let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i])) + check: + toHex(data1) == LengthVectors[i] + toHex(data2) == LengthVectors[i] + check: + getLengthDecodedValue(data1) == LengthValues[i] + getLengthDecodedValue(data2) == LengthValues[i] + + test "[length] mixing many values with same field number test": + for i in 0 ..< len(LengthValues): + var pb1 = initProtoBuffer() + var pb2 = initProtoBuffer() + for k in 0 ..< len(LengthValues): + let index = (i + k + 1) mod len(LengthValues) + pb1.write(1, LengthValues[index]) + pb2.write(1, cast[seq[byte]](LengthValues[index])) + pb1.finish() + pb2.finish() + check getLengthDecodedValue(pb1.buffer) == LengthValues[i] + check getLengthDecodedValue(pb2.buffer) == LengthValues[i] + + test "[length] incorrect values test": + for i in 0 ..< len(LengthValues): + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + + test "[length] non-existent field test": + for i in 0 ..< len(LengthValues): + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value, valueLen) == false + valueLen == 0 + + test "[length] corrupted header test": + for i in 0 ..< len(LengthValues): + for k in 0 ..< 3: + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + + test "[length] empty buffer test": + var value = newSeq[byte](len(LengthValues[0])) + var valueLen = 0 + var pb = initProtoBuffer() + check: + pb.getField(1, value, valueLen) == false + valueLen == 0 + + test "[length] buffer overflow test": + for i in 1 ..< len(LengthValues): + let data = getLengthEncodedValue(LengthValues[i]) + + var value = newString(len(LengthValues[i]) - 1) + var valueLen = 0 + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + valueLen == len(LengthValues[i]) + isFullZero(value) == true + + test "[length] mix of buffer overflow and normal fields test": + var pb1 = initProtoBuffer() + pb1.write(1, "TEST10") + pb1.write(1, "TEST20") + pb1.write(1, "TEST") + pb1.write(1, "TEST30") + pb1.write(1, "SOME") + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var value = newString(4) + var valueLen = 0 + check: + pb2.getField(1, value, valueLen) == true + value == "SOME" + + test "[length] too big message test": + var pb1 = initProtoBuffer() + var bigString = newString(MaxMessageSize + 1) + + for i in 0 ..< len(bigString): + bigString[i] = 'A' + pb1.write(1, bigString) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var value = newString(MaxMessageSize + 1) + var valueLen = 0 + check: + pb2.getField(1, value, valueLen) == false + + test "[length] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, "TEST1") + pb1.write(1, "TEST2") + pb1.write(2, "TEST5") + pb1.write(1, "TEST3") + pb1.write(1, "TEST4") + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[seq[byte]] + var fieldarr2: seq[seq[byte]] + var fieldarr3: seq[seq[byte]] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[string](fieldarr1[0]) == "TEST1" + cast[string](fieldarr1[1]) == "TEST2" + cast[string](fieldarr1[2]) == "TEST3" + cast[string](fieldarr1[3]) == "TEST4" + cast[string](fieldarr2[0]) == "TEST5" + + test "Different value types in one message with same field number test": + proc getEncodedValue(): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, VarintValues[1]) + pb.write(2, cast[float32](Fixed32Values[1])) + pb.write(3, cast[float64](Fixed64Values[1])) + pb.write(4, LengthValues[1]) + + pb.write(1, VarintValues[2]) + pb.write(2, cast[float32](Fixed32Values[2])) + pb.write(3, cast[float64](Fixed64Values[2])) + pb.write(4, LengthValues[2]) + + pb.write(1, cast[float32](Fixed32Values[3])) + pb.write(2, cast[float64](Fixed64Values[3])) + pb.write(3, LengthValues[3]) + pb.write(4, VarintValues[3]) + + pb.write(1, cast[float64](Fixed64Values[4])) + pb.write(2, LengthValues[4]) + pb.write(3, VarintValues[4]) + pb.write(4, cast[float32](Fixed32Values[4])) + + pb.write(1, VarintValues[1]) + pb.write(2, cast[float32](Fixed32Values[1])) + pb.write(3, cast[float64](Fixed64Values[1])) + pb.write(4, LengthValues[1]) + pb.finish() + pb.buffer + + let msg = getEncodedValue() + let pb = initProtoBuffer(msg) + var varintValue: uint64 + var fixed32Value: float32 + var fixed64Value: float64 + var lengthValue = newString(10) + var lengthSize: int + + check: + pb.getField(1, varintValue) == true + pb.getField(2, fixed32Value) == true + pb.getField(3, fixed64Value) == true + pb.getField(4, lengthValue, lengthSize) == true + + lengthValue.setLen(lengthSize) + + check: + varintValue == VarintValues[1] + cast[uint32](fixed32Value) == Fixed32Values[1] + cast[uint64](fixed64Value) == Fixed64Values[1] + lengthValue == LengthValues[1] diff --git a/tests/testnative.nim b/tests/testnative.nim index 53e4a244e..9f75ffec6 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -1,4 +1,5 @@ import testvarint, + testminprotobuf, teststreamseq import testrsa,