[WIP] Minprotobuf refactoring (#259)

* Minprotobuf initial commit

* Fix noise.

* Add signed integers support.
Add checks for field number value.
Remove some casts.

* Fix compile errors.

* Fix comments and constants.
This commit is contained in:
Eugene Kabanov 2020-07-13 15:43:07 +03:00 committed by GitHub
parent 181cf73ca7
commit efb952f18b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1705 additions and 413 deletions

View File

@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] =
## ##
## Returns number of bytes (octets) needed to store private key ``key``. ## Returns number of bytes (octets) needed to store private key ``key``.
var msg = initProtoBuffer() var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme))) msg.write(1, uint64(key.scheme))
msg.write(initProtoField(2, ? key.getRawBytes())) msg.write(2, ? key.getRawBytes())
msg.finish() msg.finish()
var blen = len(msg.buffer) var blen = len(msg.buffer)
if len(data) >= blen: if len(data) >= blen:
@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] =
## ##
## Returns number of bytes (octets) needed to store public key ``key``. ## Returns number of bytes (octets) needed to store public key ``key``.
var msg = initProtoBuffer() var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme))) msg.write(1, uint64(key.scheme))
msg.write(initProtoField(2, ? key.getRawBytes())) msg.write(2, ? key.getRawBytes())
msg.finish() msg.finish()
var blen = len(msg.buffer) var blen = len(msg.buffer)
if len(data) >= blen and blen > 0: if len(data) >= blen and blen > 0:
@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
## Return private key ``key`` in binary form (using libp2p's protobuf ## Return private key ``key`` in binary form (using libp2p's protobuf
## serialization). ## serialization).
var msg = initProtoBuffer() var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme))) msg.write(1, uint64(key.scheme))
msg.write(initProtoField(2, ? key.getRawBytes())) msg.write(2, ? key.getRawBytes())
msg.finish() msg.finish()
ok(msg.buffer) ok(msg.buffer)
@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
## Return public key ``key`` in binary form (using libp2p's protobuf ## Return public key ``key`` in binary form (using libp2p's protobuf
## serialization). ## serialization).
var msg = initProtoBuffer() var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme))) msg.write(1, uint64(key.scheme))
msg.write(initProtoField(2, ? key.getRawBytes())) msg.write(2, ? key.getRawBytes())
msg.finish() msg.finish()
ok(msg.buffer) ok(msg.buffer)
@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
var buffer: seq[byte] var buffer: seq[byte]
if len(data) > 0: if len(data) > 0:
var pb = initProtoBuffer(@data) var pb = initProtoBuffer(@data)
if pb.getVarintValue(1, id) != 0: if pb.getField(1, id) and pb.getField(2, buffer):
if pb.getBytes(2, buffer) != 0: if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
if cast[int8](id) in SupportedSchemesInt: var scheme = cast[PKScheme](cast[int8](id))
var scheme = cast[PKScheme](cast[int8](id)) when key is PrivateKey:
when key is PrivateKey: var nkey = PrivateKey(scheme: scheme)
var nkey = PrivateKey(scheme: scheme) else:
else: var nkey = PublicKey(scheme: scheme)
var nkey = PublicKey(scheme: scheme) case scheme:
case scheme: of PKScheme.RSA:
of PKScheme.RSA: if init(nkey.rsakey, buffer).isOk:
if init(nkey.rsakey, buffer).isOk: key = nkey
key = nkey return true
return true of PKScheme.Ed25519:
of PKScheme.Ed25519: if init(nkey.edkey, buffer):
if init(nkey.edkey, buffer): key = nkey
key = nkey return true
return true of PKScheme.ECDSA:
of PKScheme.ECDSA: if init(nkey.eckey, buffer).isOk:
if init(nkey.eckey, buffer).isOk: key = nkey
key = nkey return true
return true of PKScheme.Secp256k1:
of PKScheme.Secp256k1: if init(nkey.skkey, buffer).isOk:
if init(nkey.skkey, buffer).isOk: key = nkey
key = nkey return true
return true else:
else: return false
return false
proc init*(sig: var Signature, data: openarray[byte]): bool = proc init*(sig: var Signature, data: openarray[byte]): bool =
## Initialize signature ``sig`` from raw binary form. ## Initialize signature ``sig`` from raw binary form.
@ -727,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte],
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and ## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
## comma-delimeted list of supported hashes ``hashes``. ## comma-delimeted list of supported hashes ``hashes``.
var msg = initProtoBuffer({WithUint32BeLength}) var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, nonce)) msg.write(1, nonce)
msg.write(initProtoField(2, pubkey)) msg.write(2, pubkey)
msg.write(initProtoField(3, exchanges)) msg.write(3, exchanges)
msg.write(initProtoField(4, ciphers)) msg.write(4, ciphers)
msg.write(initProtoField(5, hashes)) msg.write(5, hashes)
msg.finish() msg.finish()
shallowCopy(result, msg.buffer) shallowCopy(result, msg.buffer)
@ -744,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
## ##
## Procedure returns ``true`` on success and ``false`` on error. ## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message) var pb = initProtoBuffer(message)
if pb.getLengthValue(1, nonce) != -1 and pb.getField(1, nonce) and pb.getField(2, pubkey) and
pb.getLengthValue(2, pubkey) != -1 and pb.getField(3, exchanges) and pb.getField(4, ciphers) and
pb.getLengthValue(3, exchanges) != -1 and pb.getField(5, hashes)
pb.getLengthValue(4, ciphers) != -1 and
pb.getLengthValue(5, hashes) != -1:
result = true
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and ## Create SecIO exchange message using ephemeral public key ``epubkey`` and
## signature of proposal blocks ``signature``. ## signature of proposal blocks ``signature``.
var msg = initProtoBuffer({WithUint32BeLength}) var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, epubkey)) msg.write(1, epubkey)
msg.write(initProtoField(2, signature)) msg.write(2, signature)
msg.finish() msg.finish()
shallowCopy(result, msg.buffer) shallowCopy(result, msg.buffer)
@ -767,9 +763,7 @@ proc decodeExchange*(message: seq[byte],
## ##
## Procedure returns ``true`` on success and ``false`` on error. ## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message) var pb = initProtoBuffer(message)
if pb.getLengthValue(1, pubkey) != -1 and pb.getField(1, pubkey) and pb.getField(2, signature)
pb.getLengthValue(2, signature) != -1:
result = true
## Serialization/Deserialization helpers ## Serialization/Deserialization helpers
@ -788,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {.
## Write Signature value ``sig`` to buffer ``vb``. ## Write Signature value ``sig`` to buffer ``vb``.
vb.writeSeq(sig.getBytes().tryGet()) vb.writeSeq(sig.getBytes().tryGet())
proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {. proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
raises: [Defect, ResultError[CryptoError]].} = key: T) {.
## Initialize ProtoField with PublicKey ``pubkey``. inline, raises: [Defect, ResultError[CryptoError]].} =
result = initProtoField(index, pubkey.getBytes().tryGet()) write(pb, field, key.getBytes().tryGet())
proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {. proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
raises: [Defect, ResultError[CryptoError]].} = inline, raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PrivateKey ``seckey``. write(pb, field, sig.getBytes())
result = initProtoField(index, seckey.getBytes().tryGet())
proc initProtoField*(index: int, sig: Signature): ProtoField = proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {.
deprecated, raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PublicKey/PrivateKey ``key``.
result = initProtoField(index, key.getBytes().tryGet())
proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} =
## Initialize ProtoField with Signature ``sig``. ## Initialize ProtoField with Signature ``sig``.
result = initProtoField(index, sig.getBytes()) result = initProtoField(index, sig.getBytes())
proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int = proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int,
## Read ``PublicKey`` from ProtoBuf's message and validate it. value: var T): int {.deprecated.} =
## Read PublicKey/PrivateKey from ProtoBuf's message and validate it.
var buf: seq[byte] var buf: seq[byte]
var key: PublicKey var key: PublicKey
result = getLengthValue(data, field, buf) result = getLengthValue(data, field, buf)
@ -813,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
else: else:
value = key value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int = proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
## Read ``PrivateKey`` from ProtoBuf's message and validate it. deprecated.} =
var buf: seq[byte]
var key: PrivateKey
result = getLengthValue(data, field, buf)
if result > 0:
if not key.init(buf):
result = -1
else:
value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
## Read ``Signature`` from ProtoBuf's message and validate it. ## Read ``Signature`` from ProtoBuf's message and validate it.
var buf: seq[byte] var buf: seq[byte]
var sig: Signature var sig: Signature
@ -834,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
result = -1 result = -1
else: else:
value = sig value = sig
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
value: var T): bool =
var buffer: seq[byte]
var key: T
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if key.init(buffer):
value = key
true
else:
false
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
var buffer: seq[byte]
var sig: Signature
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if sig.init(buffer):
value = sig
true
else:
false

View File

@ -14,9 +14,10 @@
import nativesockets import nativesockets
import tables, strutils, stew/shims/net import tables, strutils, stew/shims/net
import chronos import chronos
import multicodec, multihash, multibase, transcoder, vbuffer, peerid import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
protobuf/minprotobuf
import stew/[base58, base32, endians2, results] import stew/[base58, base32, endians2, results]
export results export results, minprotobuf, vbuffer
type type
MAKind* = enum MAKind* = enum
@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
else: else:
ok($(proto.mcodec)) ok($(proto.mcodec))
proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] = proc protoArgument*(ma: MultiAddress,
value: var openarray[byte]): MaResult[int] =
## Returns MultiAddress ``ma`` protocol argument value. ## Returns MultiAddress ``ma`` protocol argument value.
## ##
## If current MultiAddress do not have argument value, then result will be ## If current MultiAddress do not have argument value, then result will be
@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int]
var res: int var res: int
if proto.kind == Fixed: if proto.kind == Fixed:
res = proto.size res = proto.size
if len(value) >= res and if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size: vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
err("multiaddress: Decoding protocol error") err("multiaddress: Decoding protocol error")
else: else:
ok(res) ok(res)
@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
let proto = CodeAddresses.getOrDefault(MultiCodec(header)) let proto = CodeAddresses.getOrDefault(MultiCodec(header))
if proto.kind == None: if proto.kind == None:
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'") yield err(MaResult[MultiAddress], "Unsupported protocol '" &
$header & "'")
elif proto.kind == Fixed: elif proto.kind == Fixed:
data.setLen(proto.size) data.setLen(proto.size)
@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
return ok(true) return ok(true)
ok(false) ok(false)
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} = proc `[]`*(ma: MultiAddress,
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
## Returns partial MultiAddress with MultiCodec ``codec`` and present in ## Returns partial MultiAddress with MultiCodec ``codec`` and present in
## MultiAddress ``ma``. ## MultiAddress ``ma``.
for item in ma.items: for item in ma.items:
@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] =
return err("multiaddress: Unsupported protocol '" & $header & "'") return err("multiaddress: Unsupported protocol '" & $header & "'")
if proto.kind in {Fixed, Length, Path}: if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.bufferToString): if isNil(proto.coder.bufferToString):
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder") return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
"' coder")
if not proto.coder.bufferToString(vb.data, part): if not proto.coder.bufferToString(vb.data, part):
return err("multiaddress: Decoding protocol error") return err("multiaddress: Decoding protocol error")
parts.add($(proto.mcodec)) parts.add($(proto.mcodec))
@ -729,12 +734,14 @@ proc init*(
of None: of None:
raiseAssert "None checked above" raiseAssert "None checked above"
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} = proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: PeerID): MaResult[MultiAddress] {.inline.} =
## Initialize MultiAddress object from protocol id ``protocol`` and peer id ## Initialize MultiAddress object from protocol id ``protocol`` and peer id
## ``value``. ## ``value``.
init(mtype, protocol, value.data) init(mtype, protocol, value.data)
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] = proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: int): MaResult[MultiAddress] =
## Initialize MultiAddress object from protocol id ``protocol`` and integer ## Initialize MultiAddress object from protocol id ``protocol`` and integer
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``, ## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
## ``dccp`` and ``sctp`` MultiAddresses. ## ``dccp`` and ``sctp`` MultiAddresses.
@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
if mc != InvalidMultiCodec: if mc != InvalidMultiCodec:
result = CodeAddresses.getOrDefault(mc) result = CodeAddresses.getOrDefault(mc)
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] = proc init*(mtype: typedesc[MultiAddress],
value: string): MaResult[MultiAddress] =
## Initialize MultiAddress object from string representation ``value``. ## Initialize MultiAddress object from string representation ``value``.
var parts = value.trimRight('/').split('/') var parts = value.trimRight('/').split('/')
if len(parts[0]) != 0: if len(parts[0]) != 0:
@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
else: else:
if proto.kind in {Fixed, Length, Path}: if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.stringToBuffer): if isNil(proto.coder.stringToBuffer):
return err("multiaddress: Missing protocol '" & part & "' transcoder") return err("multiaddress: Missing protocol '" &
part & "' transcoder")
if offset + 1 >= len(parts): if offset + 1 >= len(parts):
return err("multiaddress: Missing protocol '" & part & "' argument") return err("multiaddress: Missing protocol '" & part & "' argument")
@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.write(proto.mcodec) res.data.write(proto.mcodec)
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data) let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
if not res: if not res:
return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`") return err("multiaddress: Error encoding `" & part & "/" &
parts[offset + 1] & "`")
offset += 2 offset += 2
elif proto.kind == Path: elif proto.kind == Path:
var path = "/" & (parts[(offset + 1)..^1].join("/")) var path = "/" & (parts[(offset + 1)..^1].join("/"))
res.data.write(proto.mcodec) res.data.write(proto.mcodec)
if not proto.coder.stringToBuffer(path, res.data): if not proto.coder.stringToBuffer(path, res.data):
return err("multiaddress: Error encoding `" & part & "/" & path & "`") return err("multiaddress: Error encoding `" & part & "/" &
path & "`")
break break
elif proto.kind == Marker: elif proto.kind == Marker:
@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.finish() res.data.finish()
ok(res) ok(res)
proc init*(mtype: typedesc[MultiAddress],
proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] = data: openarray[byte]): MaResult[MultiAddress] =
## Initialize MultiAddress with array of bytes ``data``. ## Initialize MultiAddress with array of bytes ``data``.
if len(data) == 0: if len(data) == 0:
err("multiaddress: Address could not be empty!") err("multiaddress: Address could not be empty!")
@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress],
var data = initVBuffer() var data = initVBuffer()
data.write(familyProto.mcodec) data.write(familyProto.mcodec)
var written = familyProto.coder.stringToBuffer($address, data) var written = familyProto.coder.stringToBuffer($address, data)
doAssert written, "Merely writing a string to a buffer should always be possible" doAssert written,
"Merely writing a string to a buffer should always be possible"
data.write(protoProto.mcodec) data.write(protoProto.mcodec)
written = protoProto.coder.stringToBuffer($port, data) written = protoProto.coder.stringToBuffer($port, data)
doAssert written, "Merely writing a string to a buffer should always be possible" doAssert written,
"Merely writing a string to a buffer should always be possible"
data.finish() data.finish()
MultiAddress(data: data) MultiAddress(data: data)
@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
else: else:
ok() ok()
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} = proc `&`*(m1, m2: MultiAddress): MultiAddress {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``, and returns result. ## Concatenates two addresses ``m1`` and ``m2``, and returns result.
## ##
## This procedure performs validation of concatenated result and can raise ## This procedure performs validation of concatenated result and can raise
## exception on error. ## exception on error.
concat(m1, m2).tryGet() concat(m1, m2).tryGet()
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} = proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``. ## Concatenates two addresses ``m1`` and ``m2``.
## ##
## This procedure performs validation of concatenated result and can raise ## This procedure performs validation of concatenated result and can raise
@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string =
result = "(" & sub.join("|") & ")" result = "(" & sub.join("|") & ")"
elif pat.operator == Eq: elif pat.operator == Eq:
result = $pat.value result = $pat.value
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
write(pb, field, value.data.buffer)
proc getField*(pb: var ProtoBuffer, field: int,
value: var MultiAddress): bool {.inline.} =
var buffer: seq[byte]
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
true
else:
false
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
value: var seq[MultiAddress]): bool {.inline.} =
var items: seq[seq[byte]]
value.setLen(0)
if not(getRepeatedField(pb, field, items)):
return false
if len(items) == 0:
return true
for item in items:
let ma = MultiAddress.init(item)
if ma.isOk():
value.add(ma.get())
else:
value.setLen(0)
return false

View File

@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} =
## Write PeerID value ``peerid`` to buffer ``vb``. ## Write PeerID value ``peerid`` to buffer ``vb``.
vb.writeSeq(pid.data) vb.writeSeq(pid.data)
proc initProtoField*(index: int, pid: PeerID): ProtoField = proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} =
## Initialize ProtoField with PeerID ``value``. ## Initialize ProtoField with PeerID ``value``.
result = initProtoField(index, pid.data) result = initProtoField(index, pid.data)
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int = proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {.
deprecated.} =
## Read ``PeerID`` from ProtoBuf's message and validate it. ## Read ``PeerID`` from ProtoBuf's message and validate it.
var pid: PeerID var pid: PeerID
result = getLengthValue(data, field, pid.data) result = getLengthValue(data, field, pid.data)
@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
result = -1 result = -1
else: else:
value = pid value = pid
proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
write(pb, field, pid.data)
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
## Read ``PeerID`` from ProtoBuf's message and validate it
var buffer: seq[byte]
var peerId: PeerID
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if peerId.init(buffer):
pid = peerId
true
else:
false

View File

@ -11,7 +11,7 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import ../varint import ../varint, stew/endians2
const const
MaxMessageSize* = 1'u shl 22 MaxMessageSize* = 1'u shl 22
@ -32,10 +32,14 @@ type
offset*: int offset*: int
length*: int length*: int
ProtoHeader* = object
wire*: ProtoFieldKind
index*: uint64
ProtoField* = object ProtoField* = object
## Protobuf's message field representation object ## Protobuf's message field representation object
index: int index*: int
case kind: ProtoFieldKind case kind*: ProtoFieldKind
of Varint: of Varint:
vint*: uint64 vint*: uint64
of Fixed64: of Fixed64:
@ -47,13 +51,35 @@ type
of StartGroup, EndGroup: of StartGroup, EndGroup:
discard discard
template protoHeader*(index: int, wire: ProtoFieldKind): uint = ProtoResult {.pure.} = enum
## Get protobuf's field header integer for ``index`` and ``wire``. VarintDecodeError,
((uint(index) shl 3) or cast[uint](wire)) MessageIncompleteError,
BufferOverflowError,
MessageSizeTooBigError,
NoError
template protoHeader*(field: ProtoField): uint = ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
hint | hint32 | hint64 | float32 | float64
const
SupportedWireTypes* = {
int(ProtoFieldKind.Varint),
int(ProtoFieldKind.Fixed64),
int(ProtoFieldKind.Length),
int(ProtoFieldKind.Fixed32)
}
template checkFieldNumber*(i: int) =
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
"Incorrect or reserved field number")
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
## Get protobuf's field header integer for ``index`` and ``wire``.
((uint64(index) shl 3) or uint64(wire))
template getProtoHeader*(field: ProtoField): uint64 =
## Get protobuf's field header integer for ``field``. ## Get protobuf's field header integer for ``field``.
((uint(field.index) shl 3) or cast[uint](field.kind)) ((uint64(field.index) shl 3) or uint64(field.kind))
template toOpenArray*(pb: ProtoBuffer): untyped = template toOpenArray*(pb: ProtoBuffer): untyped =
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1) toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int =
proc vsizeof*(field: ProtoField): int {.inline.} = proc vsizeof*(field: ProtoField): int {.inline.} =
## Returns number of bytes required to store protobuf's field ``field``. ## Returns number of bytes required to store protobuf's field ``field``.
result = vsizeof(protoHeader(field))
case field.kind case field.kind
of ProtoFieldKind.Varint: of ProtoFieldKind.Varint:
result += vsizeof(field.vint) vsizeof(getProtoHeader(field)) + vsizeof(field.vint)
of ProtoFieldKind.Fixed64: of ProtoFieldKind.Fixed64:
result += sizeof(field.vfloat64) vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64)
of ProtoFieldKind.Fixed32: of ProtoFieldKind.Fixed32:
result += sizeof(field.vfloat32) vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
of ProtoFieldKind.Length: of ProtoFieldKind.Length:
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer) vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
len(field.vbuffer)
else: else:
discard 0
proc initProtoField*(index: int, value: SomeVarint): ProtoField = proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value. ## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index) result = ProtoField(kind: Varint, index: index)
when type(value) is uint64: when type(value) is uint64:
@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField =
else: else:
result.vint = cast[uint64](value) result.vint = cast[uint64](value)
proc initProtoField*(index: int, value: bool): ProtoField = proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value. ## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index) result = ProtoField(kind: Varint, index: index)
result.vint = byte(value) result.vint = byte(value)
proc initProtoField*(index: int, value: openarray[byte]): ProtoField = proc initProtoField*(index: int,
value: openarray[byte]): ProtoField {.deprecated.} =
## Initialize ProtoField with bytes array. ## Initialize ProtoField with bytes array.
result = ProtoField(kind: Length, index: index) result = ProtoField(kind: Length, index: index)
if len(value) > 0: if len(value) > 0:
result.vbuffer = newSeq[byte](len(value)) result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: string): ProtoField = proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} =
## Initialize ProtoField with string. ## Initialize ProtoField with string.
result = ProtoField(kind: Length, index: index) result = ProtoField(kind: Length, index: index)
if len(value) > 0: if len(value) > 0:
result.vbuffer = newSeq[byte](len(value)) result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} = proc initProtoField*(index: int,
value: ProtoBuffer): ProtoField {.deprecated, inline.} =
## Initialize ProtoField with nested message stored in ``value``. ## Initialize ProtoField with nested message stored in ``value``.
## ##
## Note: This procedure performs shallow copy of ``value`` sequence. ## Note: This procedure performs shallow copy of ``value`` sequence.
@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0,
result.offset = offset result.offset = offset
result.options = options result.options = options
proc initProtoBuffer*(data: openarray[byte], offset = 0,
options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``. ## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeqOfCap[byte](128) result.buffer = newSeqOfCap[byte](128)
@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
result.offset = 10 result.offset = 10
elif {WithUint32LeLength, WithUint32BeLength} * options != {}: elif {WithUint32LeLength, WithUint32BeLength} * options != {}:
# Our buffer will start from position 4, so we can store length of buffer # Our buffer will start from position 4, so we can store length of buffer
# in [0, 9]. # in [0, 3].
result.buffer.setLen(4) result.buffer.setLen(4)
result.offset = 4 result.offset = 4
proc write*(pb: var ProtoBuffer, field: ProtoField) = proc write*[T: ProtoScalar](pb: var ProtoBuffer,
field: int, value: T) =
checkFieldNumber(field)
var length = 0
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
vsizeof(value)
let header = ProtoFieldKind.Varint
elif T is float32:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
sizeof(T)
let header = ProtoFieldKind.Fixed32
elif T is float64:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
sizeof(T)
let header = ProtoFieldKind.Fixed64
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, header))
doAssert(hres.isOk())
pb.offset += length
when (T is uint64) or (T is uint32) or (T is uint):
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let vres = putSVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
let u32 = cast[uint32](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
let u64 = cast[uint64](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
value: openarray[T]) =
checkFieldNumber(field)
var length = 0
let dlength =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
var res = 0
for item in value:
res += vsizeof(item)
res
elif (T is float32) or (T is float64):
len(value) * sizeof(T)
let header = getProtoHeader(field, ProtoFieldKind.Length)
let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length, header)
doAssert(hres.isOk())
pb.offset += length
length = 0
let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength))
doAssert(lres.isOk())
pb.offset += length
for item in value:
when (T is uint64) or (T is uint32) or (T is uint):
length = 0
let vres = PB.putUVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
length = 0
let vres = PB.putSVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
let u32 = cast[uint32](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
let u64 = cast[uint64](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
value: openarray[T]) =
checkFieldNumber(field)
var length = 0
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
vsizeof(uint64(len(value))) + len(value)
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, ProtoFieldKind.Length))
doAssert(hres.isOk())
pb.offset += length
let lres = PB.putUVarint(pb.toOpenArray(), length,
uint64(len(value)))
doAssert(lres.isOk())
pb.offset += length
if len(value) > 0:
doAssert(pb.isEnough(len(value)))
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
pb.offset += len(value)
proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} =
## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer
## ``pb`` with field number ``field``.
write(pb, field, value.buffer)
proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} =
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``. ## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
var length = 0 var length = 0
var res: VarintResult[void] var res: VarintResult[void]
pb.buffer.setLen(len(pb.buffer) + vsizeof(field)) pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field)) res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field))
doAssert(res.isOk()) doAssert(res.isOk())
pb.offset += length pb.offset += length
case field.kind case field.kind
@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) =
pb.offset = pos pb.offset = pos
elif WithUint32BeLength in pb.options: elif WithUint32BeLength in pb.options:
let size = uint(len(pb.buffer) - 4) let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte((size shr 24) and 0xFF'u) pb.buffer[0 ..< 4] = toBytesBE(uint32(size))
pb.buffer[1] = byte((size shr 16) and 0xFF'u)
pb.buffer[2] = byte((size shr 8) and 0xFF'u)
pb.buffer[3] = byte(size and 0xFF'u)
pb.offset = 4 pb.offset = 4
elif WithUint32LeLength in pb.options: elif WithUint32LeLength in pb.options:
let size = uint(len(pb.buffer) - 4) let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte(size and 0xFF'u) pb.buffer[0 ..< 4] = toBytesLE(uint32(size))
pb.buffer[1] = byte((size shr 8) and 0xFF'u)
pb.buffer[2] = byte((size shr 16) and 0xFF'u)
pb.buffer[3] = byte((size shr 24) and 0xFF'u)
pb.offset = 4 pb.offset = 4
else: else:
pb.offset = 0 pb.offset = 0
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
var length = 0
var hdr = 0'u64
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
let index = uint64(hdr shr 3)
let wire = hdr and 0x07
if wire in SupportedWireTypes:
data.offset += length
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
true
else:
false
else:
false
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
case header.wire
of ProtoFieldKind.Varint:
var length = 0
var value = 0'u64
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
true
else:
false
of ProtoFieldKind.Fixed32:
if data.isEnough(sizeof(uint32)):
data.offset += sizeof(uint32)
true
else:
false
of ProtoFieldKind.Fixed64:
if data.isEnough(sizeof(uint64)):
data.offset += sizeof(uint64)
true
else:
false
of ProtoFieldKind.Length:
var length = 0
var bsize = 0'u64
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
data.offset += int(bsize)
true
else:
false
else:
false
else:
false
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
false
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
header: ProtoHeader,
outval: var T): ProtoResult =
when (T is uint64) or (T is uint32) or (T is uint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
var value = T(0)
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
outval = value
ProtoResult.NoError
else:
ProtoResult.VarintDecodeError
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
var value = T(0)
if getSVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
outval = value
ProtoResult.NoError
else:
ProtoResult.VarintDecodeError
elif T is float32:
doAssert(header.wire == ProtoFieldKind.Fixed32)
if data.isEnough(sizeof(float32)):
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
data.offset += sizeof(float32)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
elif T is float64:
doAssert(header.wire == ProtoFieldKind.Fixed64)
if data.isEnough(sizeof(float64)):
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
data.offset += sizeof(float64)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var openarray[T],
outLength: var int): ProtoResult =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
outLength = 0
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ProtoResult.NoError
else:
# Buffer overflow should not be critical failure
data.offset += int(bsize)
ProtoResult.BufferOverflowError
else:
ProtoResult.MessageIncompleteError
else:
ProtoResult.MessageSizeTooBigError
else:
ProtoResult.VarintDecodeError
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var T): ProtoResult =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
outBytes.setLen(0)
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
else:
ProtoResult.MessageSizeTooBigError
else:
ProtoResult.VarintDecodeError
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var T): bool =
checkFieldNumber(field)
var value: T
var res = false
var pb = data
output = T(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output = T(0)
return false
let wireCheck =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
header.wire == ProtoFieldKind.Varint
elif T is float32:
header.wire == ProtoFieldKind.Fixed32
elif T is float64:
header.wire == ProtoFieldKind.Fixed64
if header.index == uint64(field):
if wireCheck:
let r = getValue(pb, header, value)
case r
of ProtoResult.NoError:
res = true
output = value
else:
return false
else:
# We are ignoring wire types different from what we expect, because it
# is how `protoc` is working.
if not(skipValue(pb, header)):
output = T(0)
return false
else:
if not(skipValue(pb, header)):
output = T(0)
return false
res
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
output: var openarray[T],
outlen: var int): bool =
checkFieldNumber(field)
var pb = data
var res = false
outlen = 0
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output, outlen)
case r
of ProtoResult.NoError:
res = true
of ProtoResult.BufferOverflowError:
# Buffer overflow error is not critical error, we still can get
# field values with proper size.
discard
else:
if len(output) > 0:
zeroMem(addr output[0], len(output))
return false
else:
# We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working.
if not(skipValue(pb, header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
else:
if not(skipValue(pb, header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
res
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var T): bool =
checkFieldNumber(field)
var res = false
var pb = data
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output)
case r
of ProtoResult.NoError:
res = true
of ProtoResult.BufferOverflowError:
# Buffer overflow error is not critical error, we still can get
# field values with proper size.
discard
else:
output.setLen(0)
return false
else:
# We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working.
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
res
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
inline.} =
var buffer: seq[byte]
if pb.getField(field, buffer):
output = initProtoBuffer(buffer)
true
else:
false
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
var item: T
let r = getValue(pb, header, item)
case r
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
ProtoFieldKind.Fixed64}:
var item: T
let r = getValue(pb, header, item)
case r
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
var arritem: seq[byte]
let rarr = getValue(pb, header, arritem)
case rarr
of ProtoResult.NoError:
var pbarr = initProtoBuffer(arritem)
let itemHeader =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
ProtoHeader(wire: ProtoFieldKind.Varint)
elif T is float32:
ProtoHeader(wire: ProtoFieldKind.Fixed32)
elif T is float64:
ProtoHeader(wire: ProtoFieldKind.Fixed64)
while not(pbarr.isEmpty()):
var item: T
let res = getValue(pbarr, itemHeader, item)
case res
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getVarintValue*(data: var ProtoBuffer, field: int, proc getVarintValue*(data: var ProtoBuffer, field: int,
value: var SomeVarint): int = value: var SomeVarint): int {.deprecated.} =
## Get value of `Varint` type. ## Get value of `Varint` type.
var length = 0 var length = 0
var header = 0'u64 var header = 0'u64
var soffset = data.offset var soffset = data.offset
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length data.offset += length
if header == protoHeader(field, Varint): if header == getProtoHeader(field, Varint):
if not data.isEmpty(): if not data.isEmpty():
when type(value) is int32 or type(value) is int64 or type(value) is int: when type(value) is int32 or type(value) is int64 or type(value) is int:
let res = getSVarint(data.toOpenArray(), length, value) let res = getSVarint(data.toOpenArray(), length, value)
@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int,
data.offset = soffset data.offset = soffset
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
buffer: var T): int = buffer: var T): int {.deprecated.} =
## Get value of `Length` type. ## Get value of `Length` type.
var length = 0 var length = 0
var header = 0'u64 var header = 0'u64
@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
var soffset = data.offset var soffset = data.offset
result = -1 result = -1
buffer.setLen(0) buffer.setLen(0)
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length data.offset += length
if header == protoHeader(field, Length): if header == getProtoHeader(field, Length):
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk(): if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, ssize).isOk():
data.offset += length data.offset += length
if ssize <= MaxMessageSize and data.isEnough(int(ssize)): if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
buffer.setLen(ssize) buffer.setLen(ssize)
@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
data.offset = soffset data.offset = soffset
proc getBytes*(data: var ProtoBuffer, field: int, proc getBytes*(data: var ProtoBuffer, field: int,
buffer: var seq[byte]): int {.inline.} = buffer: var seq[byte]): int {.deprecated, inline.} =
## Get value of `Length` type as bytes. ## Get value of `Length` type as bytes.
result = getLengthValue(data, field, buffer) result = getLengthValue(data, field, buffer)
proc getString*(data: var ProtoBuffer, field: int, proc getString*(data: var ProtoBuffer, field: int,
buffer: var string): int {.inline.} = buffer: var string): int {.deprecated, inline.} =
## Get value of `Length` type as string. ## Get value of `Length` type as string.
result = getLengthValue(data, field, buffer) result = getLengthValue(data, field, buffer)
proc enterSubmessage*(pb: var ProtoBuffer): int = proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} =
## Processes protobuf's sub-message and adjust internal offset to enter ## Processes protobuf's sub-message and adjust internal offset to enter
## inside of sub-message. Returns field index of sub-message field or ## inside of sub-message. Returns field index of sub-message field or
## ``0`` on error. ## ``0`` on error.
@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
var msize = 0'u64 var msize = 0'u64
var soffset = pb.offset var soffset = pb.offset
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk(): if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, header).isOk():
pb.offset += length pb.offset += length
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length): if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk(): if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, msize).isOk():
pb.offset += length pb.offset += length
if msize <= MaxMessageSize and pb.isEnough(int(msize)): if msize <= MaxMessageSize and pb.isEnough(int(msize)):
pb.length = int(msize) pb.length = int(msize)
@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
# Restore offset on error # Restore offset on error
pb.offset = soffset pb.offset = soffset
proc skipSubmessage*(pb: var ProtoBuffer) = proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} =
## Skip current protobuf's sub-message and adjust internal offset to the ## Skip current protobuf's sub-message and adjust internal offset to the
## end of sub-message. ## end of sub-message.
doAssert(pb.length != 0) doAssert(pb.length != 0)

View File

@ -47,61 +47,49 @@ type
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet())) result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
for ma in peerInfo.addrs: for ma in peerInfo.addrs:
result.write(initProtoField(2, ma.data.buffer)) result.write(2, ma.data.buffer)
for proto in peerInfo.protocols: for proto in peerInfo.protocols:
result.write(initProtoField(3, proto)) result.write(3, proto)
result.write(initProtoField(4, observedAddr.data.buffer)) result.write(4, observedAddr.data.buffer)
let protoVersion = ProtoVersion let protoVersion = ProtoVersion
result.write(initProtoField(5, protoVersion)) result.write(5, protoVersion)
let agentVersion = AgentVersion let agentVersion = AgentVersion
result.write(initProtoField(6, agentVersion)) result.write(6, agentVersion)
result.finish() result.finish()
proc decodeMsg*(buf: seq[byte]): IdentifyInfo = proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
var pb = initProtoBuffer(buf) var pb = initProtoBuffer(buf)
result.pubKey = none(PublicKey)
var pubKey: PublicKey var pubKey: PublicKey
if pb.getValue(1, pubKey) > 0: if pb.getField(1, pubKey):
trace "read public key from message", pubKey = ($pubKey).shortLog trace "read public key from message", pubKey = ($pubKey).shortLog
result.pubKey = some(pubKey) result.pubKey = some(pubKey)
result.addrs = newSeq[MultiAddress]() if pb.getRepeatedField(2, result.addrs):
var address = newSeq[byte]() trace "read addresses from message", addresses = result.addrs
while pb.getBytes(2, address) > 0:
if len(address) != 0:
var copyaddr = address
var ma = MultiAddress.init(copyaddr).tryGet()
result.addrs.add(ma)
trace "read address bytes from message", address = ma
address.setLen(0)
var proto = "" if pb.getRepeatedField(3, result.protos):
while pb.getString(3, proto) > 0: trace "read protos from message", protocols = result.protos
trace "read proto from message", proto = proto
result.protos.add(proto)
proto = ""
var observableAddr = newSeq[byte]() var observableAddr: MultiAddress
if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr if pb.getField(4, observableAddr):
var ma = MultiAddress.init(observableAddr).tryGet() trace "read observableAddr from message", address = observableAddr
trace "read observedAddr from message", address = ma result.observedAddr = some(observableAddr)
result.observedAddr = some(ma)
var protoVersion = "" var protoVersion = ""
if pb.getString(5, protoVersion) > 0: if pb.getField(5, protoVersion):
trace "read protoVersion from message", protoVersion = protoVersion trace "read protoVersion from message", protoVersion = protoVersion
result.protoVersion = some(protoVersion) result.protoVersion = some(protoVersion)
var agentVersion = "" var agentVersion = ""
if pb.getString(6, agentVersion) > 0: if pb.getField(6, agentVersion):
trace "read agentVersion from message", agentVersion = agentVersion trace "read agentVersion from message", agentVersion = agentVersion
result.agentVersion = some(agentVersion) result.agentVersion = some(agentVersion)

View File

@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string =
byteutils.toHex(m.seqno) & m.fromPeer.pretty byteutils.toHex(m.seqno) & m.fromPeer.pretty
proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} = proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} =
var buff = initProtoBuffer() p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes()
encodeMessage(msg, buff)
p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes()
proc verify*(m: Message, p: PeerInfo): bool = proc verify*(m: Message, p: PeerInfo): bool =
if m.signature.len > 0 and m.key.len > 0: if m.signature.len > 0 and m.key.len > 0:
@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool =
msg.signature = @[] msg.signature = @[]
msg.key = @[] msg.key = @[]
var buff = initProtoBuffer()
encodeMessage(msg, buff)
var remote: Signature var remote: Signature
var key: PublicKey var key: PublicKey
if remote.init(m.signature) and key.init(m.key): if remote.init(m.signature) and key.init(m.key):
trace "verifying signature", remoteSignature = remote trace "verifying signature", remoteSignature = remote
result = remote.verify(PubSubPrefix & buff.buffer, key) result = remote.verify(PubSubPrefix & encodeMessage(msg), key)
if result: if result:
libp2p_pubsub_sig_verify_success.inc() libp2p_pubsub_sig_verify_success.inc()

View File

@ -14,265 +14,247 @@ import messages,
../../../utility, ../../../utility,
../../../protobuf/minprotobuf ../../../protobuf/minprotobuf
proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} = proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
pb.write(initProtoField(1, graft.topicID)) var ipb = initProtoBuffer()
ipb.write(1, graft.topicID)
ipb.finish()
pb.write(field, ipb)
proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} = proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
trace "decoding graft msg", buffer = pb.buffer.shortLog var ipb = initProtoBuffer()
while true: ipb.write(1, prune.topicID)
var topic: string ipb.finish()
if pb.getString(1, topic) < 0: pb.write(field, ipb)
break
trace "read topic field from graft msg", topicID = topic proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
result.add(ControlGraft(topicID: topic)) var ipb = initProtoBuffer()
ipb.write(1, ihave.topicID)
proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, prune.topicID))
proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
trace "decoding prune msg"
while true:
var topic: string
if pb.getString(1, topic) < 0:
break
trace "read topic field from prune msg", topicID = topic
result.add(ControlPrune(topicID: topic))
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, ihave.topicID))
for mid in ihave.messageIDs: for mid in ihave.messageIDs:
pb.write(initProtoField(2, mid)) ipb.write(2, mid)
ipb.finish()
pb.write(field, ipb)
proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} = proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
trace "decoding ihave msg" var ipb = initProtoBuffer()
while true:
var control: ControlIHave
if pb.getString(1, control.topicID) < 0:
trace "topic field missing from ihave msg"
break
trace "read topic field", topicID = control.topicID
while true:
var mid: string
if pb.getString(2, mid) < 0:
break
trace "read messageID field", mid = mid
control.messageIDs.add(mid)
result.add(control)
proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} =
for mid in iwant.messageIDs: for mid in iwant.messageIDs:
pb.write(initProtoField(1, mid)) ipb.write(1, mid)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} = proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
trace "decoding iwant msg" var ipb = initProtoBuffer()
for ihave in control.ihave:
ipb.write(1, ihave)
for iwant in control.iwant:
ipb.write(2, iwant)
for graft in control.graft:
ipb.write(3, graft)
for prune in control.prune:
ipb.write(4, prune)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
var control: ControlIWant proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) =
while true: var ipb = initProtoBuffer()
var mid: string ipb.write(1, uint64(subs.subscribe))
if pb.getString(1, mid) < 0: ipb.write(2, subs.topic)
break ipb.finish()
control.messageIDs.add(mid) pb.write(field, ipb)
trace "read messageID field", mid = mid
result.add(control)
proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} =
if control.ihave.len > 0:
var ihave = initProtoBuffer()
for h in control.ihave:
h.encodeIHave(ihave)
# write messages to protobuf
if ihave.buffer.len > 0:
ihave.finish()
pb.write(initProtoField(1, ihave))
if control.iwant.len > 0:
var iwant = initProtoBuffer()
for w in control.iwant:
w.encodeIWant(iwant)
# write messages to protobuf
if iwant.buffer.len > 0:
iwant.finish()
pb.write(initProtoField(2, iwant))
if control.graft.len > 0:
var graft = initProtoBuffer()
for g in control.graft:
g.encodeGraft(graft)
# write messages to protobuf
if graft.buffer.len > 0:
graft.finish()
pb.write(initProtoField(3, graft))
if control.prune.len > 0:
var prune = initProtoBuffer()
for p in control.prune:
p.encodePrune(prune)
# write messages to protobuf
if prune.buffer.len > 0:
prune.finish()
pb.write(initProtoField(4, prune))
proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} =
trace "decoding control submessage"
var control: ControlMessage
while true:
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in Control msg"
break
of 1:
control.ihave &= pb.decodeIHave()
of 2:
control.iwant &= pb.decodeIWant()
of 3:
control.graft &= pb.decodeGraft()
of 4:
control.prune &= pb.decodePrune()
else:
raise newException(CatchableError, "message type not recognized")
if result.isNone:
result = some(control)
proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, subs.subscribe))
pb.write(initProtoField(2, subs.topic))
proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} =
while true:
var subOpt: SubOpts
var subscr: uint
discard pb.getVarintValue(1, subscr)
subOpt.subscribe = cast[bool](subscr)
trace "read subscribe field", subscribe = subOpt.subscribe
if pb.getString(2, subOpt.topic) < 0:
break
trace "read subscribe field", topicName = subOpt.topic
result.add(subOpt)
trace "got subscriptions", subscriptions = result
proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, msg.fromPeer.getBytes()))
pb.write(initProtoField(2, msg.data))
pb.write(initProtoField(3, msg.seqno))
for t in msg.topicIDs:
pb.write(initProtoField(4, t))
if msg.signature.len > 0:
pb.write(initProtoField(5, msg.signature))
if msg.key.len > 0:
pb.write(initProtoField(6, msg.key))
proc encodeMessage*(msg: Message): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, msg.fromPeer)
pb.write(2, msg.data)
pb.write(3, msg.seqno)
for topic in msg.topicIDs:
pb.write(4, topic)
if len(msg.signature) > 0:
pb.write(5, msg.signature)
if len(msg.key) > 0:
pb.write(6, msg.key)
pb.finish() pb.finish()
pb.buffer
proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} = proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
# TODO: which of this fields are really optional? pb.write(field, encodeMessage(msg))
while true:
var msg: Message
var fromPeer: seq[byte]
if pb.getBytes(1, fromPeer) < 0:
break
try:
msg.fromPeer = PeerID.init(fromPeer).tryGet()
except CatchableError as err:
debug "Invalid fromPeer in message", msg = err.msg
break
trace "read message field", fromPeer = msg.fromPeer.pretty proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
trace "decodeGraft: decoding message"
var control = ControlGraft()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeGraft: read topicId", topic_id = topicId
else:
trace "decodeGraft: topicId is missing"
control
if pb.getBytes(2, msg.data) < 0: proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
break trace "decodePrune: decoding message"
trace "read message field", data = msg.data.shortLog var control = ControlPrune()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodePrune: read topicId", topic_id = topicId
else:
trace "decodePrune: topicId is missing"
control
if pb.getBytes(3, msg.seqno) < 0: proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
break trace "decodeIHave: decoding message"
trace "read message field", seqno = msg.seqno.shortLog var control = ControlIHave()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeIHave: read topicId", topic_id = topicId
else:
trace "decodeIHave: topicId is missing"
if pb.getRepeatedField(2, control.messageIDs):
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIHave: no messageIDs"
control
var topic: string proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
while true: trace "decodeIWant: decoding message"
if pb.getString(4, topic) < 0: var control = ControlIWant()
break if pb.getRepeatedField(1, control.messageIDs):
msg.topicIDs.add(topic) trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
trace "read message field", topicName = topic else:
topic = "" trace "decodeIWant: no messageIDs"
discard pb.getBytes(5, msg.signature) proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
trace "read message field", signature = msg.signature.shortLog trace "decodeControl: decoding message"
var buffer: seq[byte]
if pb.getField(3, buffer):
var control: ControlMessage
var cpb = initProtoBuffer(buffer)
var ihavepbs: seq[seq[byte]]
var iwantpbs: seq[seq[byte]]
var graftpbs: seq[seq[byte]]
var prunepbs: seq[seq[byte]]
discard pb.getBytes(6, msg.key) discard cpb.getRepeatedField(1, ihavepbs)
trace "read message field", key = msg.key.shortLog discard cpb.getRepeatedField(2, iwantpbs)
discard cpb.getRepeatedField(3, graftpbs)
discard cpb.getRepeatedField(4, prunepbs)
result.add(msg) for item in ihavepbs:
control.ihave.add(decodeIHave(initProtoBuffer(item)))
for item in iwantpbs:
control.iwant.add(decodeIWant(initProtoBuffer(item)))
for item in graftpbs:
control.graft.add(decodeGraft(initProtoBuffer(item)))
for item in prunepbs:
control.prune.add(decodePrune(initProtoBuffer(item)))
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = trace "decodeControl: "
result = initProtoBuffer() some(control)
trace "encoding msg: ", msg = msg.shortLog else:
none[ControlMessage]()
if msg.subscriptions.len > 0: proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
for s in msg.subscriptions: trace "decodeSubscription: decoding message"
var subs = initProtoBuffer() var subflag: uint64
encodeSubs(s, subs) var sub = SubOpts()
if pb.getField(1, subflag):
sub.subscribe = bool(subflag)
trace "decodeSubscription: read subscribe", subscribe = subflag
else:
trace "decodeSubscription: subscribe is missing"
if pb.getField(2, sub.topic):
trace "decodeSubscription: read topic", topic = sub.topic
else:
trace "decodeSubscription: topic is missing"
# write subscriptions to protobuf sub
if subs.buffer.len > 0:
subs.finish()
result.write(initProtoField(1, subs))
if msg.messages.len > 0: proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
var messages = initProtoBuffer() trace "decodeSubscriptions: decoding message"
for m in msg.messages: var subpbs: seq[seq[byte]]
encodeMessage(m, messages) var subs: seq[SubOpts]
if pb.getRepeatedField(1, subpbs):
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
for item in subpbs:
let sub = decodeSubscription(initProtoBuffer(item))
subs.add(sub)
# write messages to protobuf if len(subs) == 0:
if messages.buffer.len > 0: trace "decodeSubscription: no subscriptions found"
messages.finish()
result.write(initProtoField(2, messages))
if msg.control.isSome: subs
var control = initProtoBuffer()
msg.control.get.encodeControl(control)
# write messages to protobuf proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
if control.buffer.len > 0: trace "decodeMessage: decoding message"
control.finish() var msg: Message
result.write(initProtoField(3, control)) if pb.getField(1, msg.fromPeer):
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
else:
trace "decodeMessage: fromPeer is missing"
if result.buffer.len > 0: if pb.getField(2, msg.data):
result.finish() trace "decodeMessage: read data", data = msg.data.shortLog()
else:
trace "decodeMessage: data is missing"
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} = if pb.getField(3, msg.seqno):
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
else:
trace "decodeMessage: seqno is missing"
if pb.getRepeatedField(4, msg.topicIDs):
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
else:
trace "decodeMessage: topics are missing"
if pb.getField(5, msg.signature):
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
else:
trace "decodeMessage: signature is missing"
if pb.getField(6, msg.key):
trace "decodeMessage: read public key", key = msg.key.shortLog()
else:
trace "decodeMessage: public key is missing"
msg
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
trace "decodeMessages: decoding message"
var msgpbs: seq[seq[byte]]
var msgs: seq[Message]
if pb.getRepeatedField(2, msgpbs):
trace "decodeMessages: read messages", count = len(msgpbs)
for item in msgpbs:
let msg = decodeMessage(initProtoBuffer(item))
msgs.add(msg)
if len(msgs) == 0:
trace "decodeMessages: no messages found"
msgs
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
var pb = initProtoBuffer()
for item in msg.subscriptions:
pb.write(1, item)
for item in msg.messages:
pb.write(2, item)
if msg.control.isSome():
pb.write(3, msg.control.get())
if len(pb.buffer) > 0:
pb.finish()
result = pb
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
var pb = initProtoBuffer(msg) var pb = initProtoBuffer(msg)
var rpcMsg: RPCMsg
rpcMsg.messages = pb.decodeMessages()
rpcMsg.subscriptions = pb.decodeSubscriptions()
rpcMsg.control = pb.decodeControl()
while true: rpcMsg
# decode SubOpts array
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in RPC msg"
break
of 1:
result.subscriptions &= pb.decodeSubs()
of 2:
result.messages &= pb.decodeMessages()
of 3:
result.control = pb.decodeControl()
else:
raise newException(CatchableError, "message type not recognized")

View File

@ -430,8 +430,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
var var
libp2pProof = initProtoBuffer() libp2pProof = initProtoBuffer()
libp2pProof.write(initProtoField(1, p.localPublicKey)) libp2pProof.write(1, p.localPublicKey)
libp2pProof.write(initProtoField(2, signedPayload.getBytes())) libp2pProof.write(2, signedPayload.getBytes())
# data field also there but not used! # data field also there but not used!
libp2pProof.finish() libp2pProof.finish()
@ -449,9 +449,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
remoteSig: Signature remoteSig: Signature
remoteSigBytes: seq[byte] remoteSigBytes: seq[byte]
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: if not(remoteProof.getField(1, remotePubKeyBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: if not(remoteProof.getField(2, remoteSigBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes): if not remotePubKey.init(remotePubKeyBytes):

677
tests/testminprotobuf.nim Normal file
View File

@ -0,0 +1,677 @@
## Nim-Libp2p
## Copyright (c) 2018 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import unittest
import ../libp2p/protobuf/minprotobuf
import stew/byteutils, strutils
when defined(nimHasUsed): {.used.}
suite "MinProtobuf test suite":
const VarintVectors = [
"0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f",
"08ffffffffffffffffff01"
]
const VarintValues = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64
]
const Fixed32Vectors = [
"0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff"
]
const Fixed32Values = [
0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32
]
const Fixed64Vectors = [
"090000000000000000", "090100000000000000", "09ffffff7f00000000",
"09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f",
"099988ffeeddccbbaa", "09ffffffffffffffff"
]
const Fixed64Values = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64,
0xFFFF_FFFF_FFFF_FFFF'u64
]
const LengthVectors = [
"0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768"
]
const LengthValues = [
"", "a", "ab", "abcd", "abcdefgh"
]
## This vector values was tested with `protoc` and related proto file.
## syntax = "proto2";
## message testmsg {
## repeated uint64 d = 1 [packed=true];
## repeated uint64 d = 2 [packed=true];
## }
const PackedVarintVector =
"0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001"
## syntax = "proto2";
## message testmsg {
## repeated sfixed32 d = 1 [packed=true];
## repeated sfixed32 d = 2 [packed=true];
## }
const PackedFixed32Vector =
"0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000"
## syntax = "proto2";
## message testmsg {
## repeated sfixed64 d = 1 [packed=true];
## repeated sfixed64 d = 2 [packed=true];
## }
const PackedFixed64Vector =
"""0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000
ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210
00000000000000000100000000000000"""
proc getVarintEncodedValue(value: uint64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getVarintDecodedValue(data: openarray[byte]): uint64 =
var value: uint64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
value
proc getFixed32EncodedValue(value: float32): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed32DecodedValue(data: openarray[byte]): uint32 =
var value: float32
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint32](value)
proc getFixed64EncodedValue(value: float64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed64DecodedValue(data: openarray[byte]): uint64 =
var value: float64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint64](value)
proc getLengthEncodedValue(value: string): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthEncodedValue(value: seq[byte]): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthDecodedValue(data: openarray[byte]): string =
var value = newString(len(data))
var valueLen = 0
var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
doAssert(res)
value.setLen(valueLen)
value
proc isFullZero[T: byte|char](data: openarray[T]): bool =
for ch in data:
if int(ch) != 0:
return false
return true
proc corruptHeader(data: var openarray[byte], index: int) =
var values = [3, 4, 6]
data[0] = data[0] and 0xF8'u8
data[0] = data[0] or byte(values[index mod len(values)])
test "[varint] edge values test":
for i in 0 ..< len(VarintValues):
let data = getVarintEncodedValue(VarintValues[i])
check:
toHex(data) == VarintVectors[i]
getVarintDecodedValue(data) == VarintValues[i]
test "[varint] mixing many values with same field number test":
for i in 0 ..< len(VarintValues):
var pb = initProtoBuffer()
for k in 0 ..< len(VarintValues):
let index = (i + k + 1) mod len(VarintValues)
pb.write(1, VarintValues[index])
pb.finish()
check getVarintDecodedValue(pb.buffer) == VarintValues[i]
test "[varint] incorrect values test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] non-existent field test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == 0'u64
test "[varint] corrupted header test":
for i in 0 ..< len(VarintValues):
for k in 0 ..< 3:
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] empty buffer test":
var value: uint64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == 0'u64
test "[varint] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, VarintValues[1])
pb1.write(1, VarintValues[2])
pb1.write(2, VarintValues[3])
pb1.write(1, VarintValues[4])
pb1.write(1, VarintValues[5])
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
fieldarr1[0] == VarintValues[1]
fieldarr1[1] == VarintValues[2]
fieldarr1[2] == VarintValues[4]
fieldarr1[3] == VarintValues[5]
fieldarr2[0] == VarintValues[3]
test "[varint] Repeated packed field test":
var pb1 = initProtoBuffer()
pb1.writePacked(1, VarintValues)
pb1.writePacked(2, VarintValues[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedVarintVector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 6
fieldarr1[0] == VarintValues[0]
fieldarr1[1] == VarintValues[1]
fieldarr1[2] == VarintValues[2]
fieldarr1[3] == VarintValues[3]
fieldarr1[4] == VarintValues[4]
fieldarr1[5] == VarintValues[5]
fieldarr2[0] == VarintValues[0]
fieldarr2[1] == VarintValues[1]
test "[fixed32] edge values test":
for i in 0 ..< len(Fixed32Values):
let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i]))
check:
toHex(data) == Fixed32Vectors[i]
getFixed32DecodedValue(data) == Fixed32Values[i]
test "[fixed32] mixing many values with same field number test":
for i in 0 ..< len(Fixed32Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed32Values):
let index = (i + k + 1) mod len(Fixed32Values)
pb.write(1, cast[float32](Fixed32Values[index]))
pb.finish()
check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i]
test "[fixed32] incorrect values test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] non-existent field test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float32(0)
test "[fixed32] corrupted header test":
for i in 0 ..< len(Fixed32Values):
for k in 0 ..< 3:
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] empty buffer test":
var value: float32
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float32(0)
test "[fixed32] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float32](Fixed32Values[0]))
pb1.write(1, cast[float32](Fixed32Values[1]))
pb1.write(2, cast[float32](Fixed32Values[2]))
pb1.write(1, cast[float32](Fixed32Values[3]))
pb1.write(1, cast[float32](Fixed32Values[4]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint32](fieldarr1[0]) == Fixed64Values[0]
cast[uint32](fieldarr1[1]) == Fixed64Values[1]
cast[uint32](fieldarr1[2]) == Fixed64Values[3]
cast[uint32](fieldarr1[3]) == Fixed64Values[4]
cast[uint32](fieldarr2[0]) == Fixed64Values[2]
test "[fixed32] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float32](len(Fixed32Values))
for i in 0 ..< len(values):
values[i] = cast[float32](Fixed32Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedFixed32Vector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 5
cast[uint32](fieldarr1[0]) == Fixed32Values[0]
cast[uint32](fieldarr1[1]) == Fixed32Values[1]
cast[uint32](fieldarr1[2]) == Fixed32Values[2]
cast[uint32](fieldarr1[3]) == Fixed32Values[3]
cast[uint32](fieldarr1[4]) == Fixed32Values[4]
cast[uint32](fieldarr2[0]) == Fixed32Values[0]
cast[uint32](fieldarr2[1]) == Fixed32Values[1]
test "[fixed64] edge values test":
for i in 0 ..< len(Fixed64Values):
let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
check:
toHex(data) == Fixed64Vectors[i]
getFixed64DecodedValue(data) == Fixed64Values[i]
test "[fixed64] mixing many values with same field number test":
for i in 0 ..< len(Fixed64Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed64Values):
let index = (i + k + 1) mod len(Fixed64Values)
pb.write(1, cast[float64](Fixed64Values[index]))
pb.finish()
check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i]
test "[fixed64] incorrect values test":
for i in 0 ..< len(Fixed64Values):
var value: float32
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] non-existent field test":
for i in 0 ..< len(Fixed64Values):
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float64(0)
test "[fixed64] corrupted header test":
for i in 0 ..< len(Fixed64Values):
for k in 0 ..< 3:
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] empty buffer test":
var value: float64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float64(0)
test "[fixed64] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float64](Fixed64Values[2]))
pb1.write(1, cast[float64](Fixed64Values[3]))
pb1.write(2, cast[float64](Fixed64Values[4]))
pb1.write(1, cast[float64](Fixed64Values[5]))
pb1.write(1, cast[float64](Fixed64Values[6]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint64](fieldarr1[0]) == Fixed64Values[2]
cast[uint64](fieldarr1[1]) == Fixed64Values[3]
cast[uint64](fieldarr1[2]) == Fixed64Values[5]
cast[uint64](fieldarr1[3]) == Fixed64Values[6]
cast[uint64](fieldarr2[0]) == Fixed64Values[4]
test "[fixed64] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float64](len(Fixed64Values))
for i in 0 ..< len(values):
values[i] = cast[float64](Fixed64Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", ""))
check:
toHex(pb1.buffer) == expect
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 8
cast[uint64](fieldarr1[0]) == Fixed64Values[0]
cast[uint64](fieldarr1[1]) == Fixed64Values[1]
cast[uint64](fieldarr1[2]) == Fixed64Values[2]
cast[uint64](fieldarr1[3]) == Fixed64Values[3]
cast[uint64](fieldarr1[4]) == Fixed64Values[4]
cast[uint64](fieldarr1[5]) == Fixed64Values[5]
cast[uint64](fieldarr1[6]) == Fixed64Values[6]
cast[uint64](fieldarr1[7]) == Fixed64Values[7]
cast[uint64](fieldarr2[0]) == Fixed64Values[0]
cast[uint64](fieldarr2[1]) == Fixed64Values[1]
test "[length] edge values test":
for i in 0 ..< len(LengthValues):
let data1 = getLengthEncodedValue(LengthValues[i])
let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i]))
check:
toHex(data1) == LengthVectors[i]
toHex(data2) == LengthVectors[i]
check:
getLengthDecodedValue(data1) == LengthValues[i]
getLengthDecodedValue(data2) == LengthValues[i]
test "[length] mixing many values with same field number test":
for i in 0 ..< len(LengthValues):
var pb1 = initProtoBuffer()
var pb2 = initProtoBuffer()
for k in 0 ..< len(LengthValues):
let index = (i + k + 1) mod len(LengthValues)
pb1.write(1, LengthValues[index])
pb2.write(1, cast[seq[byte]](LengthValues[index]))
pb1.finish()
pb2.finish()
check getLengthDecodedValue(pb1.buffer) == LengthValues[i]
check getLengthDecodedValue(pb2.buffer) == LengthValues[i]
test "[length] incorrect values test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] non-existent field test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value, valueLen) == false
valueLen == 0
test "[length] corrupted header test":
for i in 0 ..< len(LengthValues):
for k in 0 ..< 3:
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] empty buffer test":
var value = newSeq[byte](len(LengthValues[0]))
var valueLen = 0
var pb = initProtoBuffer()
check:
pb.getField(1, value, valueLen) == false
valueLen == 0
test "[length] buffer overflow test":
for i in 1 ..< len(LengthValues):
let data = getLengthEncodedValue(LengthValues[i])
var value = newString(len(LengthValues[i]) - 1)
var valueLen = 0
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
valueLen == len(LengthValues[i])
isFullZero(value) == true
test "[length] mix of buffer overflow and normal fields test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST10")
pb1.write(1, "TEST20")
pb1.write(1, "TEST")
pb1.write(1, "TEST30")
pb1.write(1, "SOME")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(4)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == true
value == "SOME"
test "[length] too big message test":
var pb1 = initProtoBuffer()
var bigString = newString(MaxMessageSize + 1)
for i in 0 ..< len(bigString):
bigString[i] = 'A'
pb1.write(1, bigString)
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(MaxMessageSize + 1)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == false
test "[length] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST1")
pb1.write(1, "TEST2")
pb1.write(2, "TEST5")
pb1.write(1, "TEST3")
pb1.write(1, "TEST4")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[seq[byte]]
var fieldarr2: seq[seq[byte]]
var fieldarr3: seq[seq[byte]]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[string](fieldarr1[0]) == "TEST1"
cast[string](fieldarr1[1]) == "TEST2"
cast[string](fieldarr1[2]) == "TEST3"
cast[string](fieldarr1[3]) == "TEST4"
cast[string](fieldarr2[0]) == "TEST5"
test "Different value types in one message with same field number test":
proc getEncodedValue(): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.write(1, VarintValues[2])
pb.write(2, cast[float32](Fixed32Values[2]))
pb.write(3, cast[float64](Fixed64Values[2]))
pb.write(4, LengthValues[2])
pb.write(1, cast[float32](Fixed32Values[3]))
pb.write(2, cast[float64](Fixed64Values[3]))
pb.write(3, LengthValues[3])
pb.write(4, VarintValues[3])
pb.write(1, cast[float64](Fixed64Values[4]))
pb.write(2, LengthValues[4])
pb.write(3, VarintValues[4])
pb.write(4, cast[float32](Fixed32Values[4]))
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.finish()
pb.buffer
let msg = getEncodedValue()
let pb = initProtoBuffer(msg)
var varintValue: uint64
var fixed32Value: float32
var fixed64Value: float64
var lengthValue = newString(10)
var lengthSize: int
check:
pb.getField(1, varintValue) == true
pb.getField(2, fixed32Value) == true
pb.getField(3, fixed64Value) == true
pb.getField(4, lengthValue, lengthSize) == true
lengthValue.setLen(lengthSize)
check:
varintValue == VarintValues[1]
cast[uint32](fixed32Value) == Fixed32Values[1]
cast[uint64](fixed64Value) == Fixed64Values[1]
lengthValue == LengthValues[1]

View File

@ -1,4 +1,5 @@
import testvarint, import testvarint,
testminprotobuf,
teststreamseq teststreamseq
import testrsa, import testrsa,