[WIP] Minprotobuf refactoring (#259)

* Minprotobuf initial commit

* Fix noise.

* Add signed integers support.
Add checks for field number value.
Remove some casts.

* Fix compile errors.

* Fix comments and constants.
This commit is contained in:
Eugene Kabanov 2020-07-13 15:43:07 +03:00 committed by GitHub
parent 181cf73ca7
commit efb952f18b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1705 additions and 413 deletions

View File

@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] =
##
## Returns number of bytes (octets) needed to store private key ``key``.
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen:
@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] =
##
## Returns number of bytes (octets) needed to store public key ``key``.
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen and blen > 0:
@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
## Return private key ``key`` in binary form (using libp2p's protobuf
## serialization).
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
## Return public key ``key`` in binary form (using libp2p's protobuf
## serialization).
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
var buffer: seq[byte]
if len(data) > 0:
var pb = initProtoBuffer(@data)
if pb.getVarintValue(1, id) != 0:
if pb.getBytes(2, buffer) != 0:
if cast[int8](id) in SupportedSchemesInt:
var scheme = cast[PKScheme](cast[int8](id))
when key is PrivateKey:
var nkey = PrivateKey(scheme: scheme)
else:
var nkey = PublicKey(scheme: scheme)
case scheme:
of PKScheme.RSA:
if init(nkey.rsakey, buffer).isOk:
key = nkey
return true
of PKScheme.Ed25519:
if init(nkey.edkey, buffer):
key = nkey
return true
of PKScheme.ECDSA:
if init(nkey.eckey, buffer).isOk:
key = nkey
return true
of PKScheme.Secp256k1:
if init(nkey.skkey, buffer).isOk:
key = nkey
return true
else:
return false
if pb.getField(1, id) and pb.getField(2, buffer):
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
var scheme = cast[PKScheme](cast[int8](id))
when key is PrivateKey:
var nkey = PrivateKey(scheme: scheme)
else:
var nkey = PublicKey(scheme: scheme)
case scheme:
of PKScheme.RSA:
if init(nkey.rsakey, buffer).isOk:
key = nkey
return true
of PKScheme.Ed25519:
if init(nkey.edkey, buffer):
key = nkey
return true
of PKScheme.ECDSA:
if init(nkey.eckey, buffer).isOk:
key = nkey
return true
of PKScheme.Secp256k1:
if init(nkey.skkey, buffer).isOk:
key = nkey
return true
else:
return false
proc init*(sig: var Signature, data: openarray[byte]): bool =
## Initialize signature ``sig`` from raw binary form.
@ -727,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte],
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
## comma-delimeted list of supported hashes ``hashes``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, nonce))
msg.write(initProtoField(2, pubkey))
msg.write(initProtoField(3, exchanges))
msg.write(initProtoField(4, ciphers))
msg.write(initProtoField(5, hashes))
msg.write(1, nonce)
msg.write(2, pubkey)
msg.write(3, exchanges)
msg.write(4, ciphers)
msg.write(5, hashes)
msg.finish()
shallowCopy(result, msg.buffer)
@ -744,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
if pb.getLengthValue(1, nonce) != -1 and
pb.getLengthValue(2, pubkey) != -1 and
pb.getLengthValue(3, exchanges) != -1 and
pb.getLengthValue(4, ciphers) != -1 and
pb.getLengthValue(5, hashes) != -1:
result = true
pb.getField(1, nonce) and pb.getField(2, pubkey) and
pb.getField(3, exchanges) and pb.getField(4, ciphers) and
pb.getField(5, hashes)
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
## signature of proposal blocks ``signature``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, epubkey))
msg.write(initProtoField(2, signature))
msg.write(1, epubkey)
msg.write(2, signature)
msg.finish()
shallowCopy(result, msg.buffer)
@ -767,9 +763,7 @@ proc decodeExchange*(message: seq[byte],
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
if pb.getLengthValue(1, pubkey) != -1 and
pb.getLengthValue(2, signature) != -1:
result = true
pb.getField(1, pubkey) and pb.getField(2, signature)
## Serialization/Deserialization helpers
@ -788,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {.
## Write Signature value ``sig`` to buffer ``vb``.
vb.writeSeq(sig.getBytes().tryGet())
proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {.
raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PublicKey ``pubkey``.
result = initProtoField(index, pubkey.getBytes().tryGet())
proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
key: T) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
write(pb, field, key.getBytes().tryGet())
proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {.
raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PrivateKey ``seckey``.
result = initProtoField(index, seckey.getBytes().tryGet())
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
write(pb, field, sig.getBytes())
proc initProtoField*(index: int, sig: Signature): ProtoField =
proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {.
deprecated, raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PublicKey/PrivateKey ``key``.
result = initProtoField(index, key.getBytes().tryGet())
proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} =
## Initialize ProtoField with Signature ``sig``.
result = initProtoField(index, sig.getBytes())
proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
## Read ``PublicKey`` from ProtoBuf's message and validate it.
proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int,
value: var T): int {.deprecated.} =
## Read PublicKey/PrivateKey from ProtoBuf's message and validate it.
var buf: seq[byte]
var key: PublicKey
result = getLengthValue(data, field, buf)
@ -813,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
else:
value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int =
## Read ``PrivateKey`` from ProtoBuf's message and validate it.
var buf: seq[byte]
var key: PrivateKey
result = getLengthValue(data, field, buf)
if result > 0:
if not key.init(buf):
result = -1
else:
value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
deprecated.} =
## Read ``Signature`` from ProtoBuf's message and validate it.
var buf: seq[byte]
var sig: Signature
@ -834,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
result = -1
else:
value = sig
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
value: var T): bool =
var buffer: seq[byte]
var key: T
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if key.init(buffer):
value = key
true
else:
false
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
var buffer: seq[byte]
var sig: Signature
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if sig.init(buffer):
value = sig
true
else:
false

View File

@ -14,9 +14,10 @@
import nativesockets
import tables, strutils, stew/shims/net
import chronos
import multicodec, multihash, multibase, transcoder, vbuffer, peerid
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
protobuf/minprotobuf
import stew/[base58, base32, endians2, results]
export results
export results, minprotobuf, vbuffer
type
MAKind* = enum
@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
else:
ok($(proto.mcodec))
proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] =
proc protoArgument*(ma: MultiAddress,
value: var openarray[byte]): MaResult[int] =
## Returns MultiAddress ``ma`` protocol argument value.
##
## If current MultiAddress do not have argument value, then result will be
@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int]
var res: int
if proto.kind == Fixed:
res = proto.size
if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
err("multiaddress: Decoding protocol error")
else:
ok(res)
@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
if proto.kind == None:
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'")
yield err(MaResult[MultiAddress], "Unsupported protocol '" &
$header & "'")
elif proto.kind == Fixed:
data.setLen(proto.size)
@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
return ok(true)
ok(false)
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
proc `[]`*(ma: MultiAddress,
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
## Returns partial MultiAddress with MultiCodec ``codec`` and present in
## MultiAddress ``ma``.
for item in ma.items:
@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] =
return err("multiaddress: Unsupported protocol '" & $header & "'")
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.bufferToString):
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder")
return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
"' coder")
if not proto.coder.bufferToString(vb.data, part):
return err("multiaddress: Decoding protocol error")
parts.add($(proto.mcodec))
@ -729,12 +734,14 @@ proc init*(
of None:
raiseAssert "None checked above"
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: PeerID): MaResult[MultiAddress] {.inline.} =
## Initialize MultiAddress object from protocol id ``protocol`` and peer id
## ``value``.
init(mtype, protocol, value.data)
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: int): MaResult[MultiAddress] =
## Initialize MultiAddress object from protocol id ``protocol`` and integer
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
## ``dccp`` and ``sctp`` MultiAddresses.
@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
if mc != InvalidMultiCodec:
result = CodeAddresses.getOrDefault(mc)
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
value: string): MaResult[MultiAddress] =
## Initialize MultiAddress object from string representation ``value``.
var parts = value.trimRight('/').split('/')
if len(parts[0]) != 0:
@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
else:
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.stringToBuffer):
return err("multiaddress: Missing protocol '" & part & "' transcoder")
return err("multiaddress: Missing protocol '" &
part & "' transcoder")
if offset + 1 >= len(parts):
return err("multiaddress: Missing protocol '" & part & "' argument")
@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.write(proto.mcodec)
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
if not res:
return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`")
return err("multiaddress: Error encoding `" & part & "/" &
parts[offset + 1] & "`")
offset += 2
elif proto.kind == Path:
var path = "/" & (parts[(offset + 1)..^1].join("/"))
res.data.write(proto.mcodec)
if not proto.coder.stringToBuffer(path, res.data):
return err("multiaddress: Error encoding `" & part & "/" & path & "`")
return err("multiaddress: Error encoding `" & part & "/" &
path & "`")
break
elif proto.kind == Marker:
@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.finish()
ok(res)
proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
data: openarray[byte]): MaResult[MultiAddress] =
## Initialize MultiAddress with array of bytes ``data``.
if len(data) == 0:
err("multiaddress: Address could not be empty!")
@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress],
var data = initVBuffer()
data.write(familyProto.mcodec)
var written = familyProto.coder.stringToBuffer($address, data)
doAssert written, "Merely writing a string to a buffer should always be possible"
doAssert written,
"Merely writing a string to a buffer should always be possible"
data.write(protoProto.mcodec)
written = protoProto.coder.stringToBuffer($port, data)
doAssert written, "Merely writing a string to a buffer should always be possible"
doAssert written,
"Merely writing a string to a buffer should always be possible"
data.finish()
MultiAddress(data: data)
@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
else:
ok()
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} =
proc `&`*(m1, m2: MultiAddress): MultiAddress {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``, and returns result.
##
## This procedure performs validation of concatenated result and can raise
## exception on error.
concat(m1, m2).tryGet()
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} =
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``.
##
## This procedure performs validation of concatenated result and can raise
@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string =
result = "(" & sub.join("|") & ")"
elif pat.operator == Eq:
result = $pat.value
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
write(pb, field, value.data.buffer)
proc getField*(pb: var ProtoBuffer, field: int,
value: var MultiAddress): bool {.inline.} =
var buffer: seq[byte]
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
true
else:
false
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
value: var seq[MultiAddress]): bool {.inline.} =
var items: seq[seq[byte]]
value.setLen(0)
if not(getRepeatedField(pb, field, items)):
return false
if len(items) == 0:
return true
for item in items:
let ma = MultiAddress.init(item)
if ma.isOk():
value.add(ma.get())
else:
value.setLen(0)
return false

View File

@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} =
## Write PeerID value ``peerid`` to buffer ``vb``.
vb.writeSeq(pid.data)
proc initProtoField*(index: int, pid: PeerID): ProtoField =
proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} =
## Initialize ProtoField with PeerID ``value``.
result = initProtoField(index, pid.data)
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {.
deprecated.} =
## Read ``PeerID`` from ProtoBuf's message and validate it.
var pid: PeerID
result = getLengthValue(data, field, pid.data)
@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
result = -1
else:
value = pid
proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
write(pb, field, pid.data)
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
## Read ``PeerID`` from ProtoBuf's message and validate it
var buffer: seq[byte]
var peerId: PeerID
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if peerId.init(buffer):
pid = peerId
true
else:
false

View File

@ -11,7 +11,7 @@
{.push raises: [Defect].}
import ../varint
import ../varint, stew/endians2
const
MaxMessageSize* = 1'u shl 22
@ -32,10 +32,14 @@ type
offset*: int
length*: int
ProtoHeader* = object
wire*: ProtoFieldKind
index*: uint64
ProtoField* = object
## Protobuf's message field representation object
index: int
case kind: ProtoFieldKind
index*: int
case kind*: ProtoFieldKind
of Varint:
vint*: uint64
of Fixed64:
@ -47,13 +51,35 @@ type
of StartGroup, EndGroup:
discard
template protoHeader*(index: int, wire: ProtoFieldKind): uint =
## Get protobuf's field header integer for ``index`` and ``wire``.
((uint(index) shl 3) or cast[uint](wire))
ProtoResult {.pure.} = enum
VarintDecodeError,
MessageIncompleteError,
BufferOverflowError,
MessageSizeTooBigError,
NoError
template protoHeader*(field: ProtoField): uint =
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
hint | hint32 | hint64 | float32 | float64
const
SupportedWireTypes* = {
int(ProtoFieldKind.Varint),
int(ProtoFieldKind.Fixed64),
int(ProtoFieldKind.Length),
int(ProtoFieldKind.Fixed32)
}
template checkFieldNumber*(i: int) =
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
"Incorrect or reserved field number")
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
## Get protobuf's field header integer for ``index`` and ``wire``.
((uint64(index) shl 3) or uint64(wire))
template getProtoHeader*(field: ProtoField): uint64 =
## Get protobuf's field header integer for ``field``.
((uint(field.index) shl 3) or cast[uint](field.kind))
((uint64(field.index) shl 3) or uint64(field.kind))
template toOpenArray*(pb: ProtoBuffer): untyped =
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int =
proc vsizeof*(field: ProtoField): int {.inline.} =
## Returns number of bytes required to store protobuf's field ``field``.
result = vsizeof(protoHeader(field))
case field.kind
of ProtoFieldKind.Varint:
result += vsizeof(field.vint)
vsizeof(getProtoHeader(field)) + vsizeof(field.vint)
of ProtoFieldKind.Fixed64:
result += sizeof(field.vfloat64)
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64)
of ProtoFieldKind.Fixed32:
result += sizeof(field.vfloat32)
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
of ProtoFieldKind.Length:
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer)
vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
len(field.vbuffer)
else:
discard
0
proc initProtoField*(index: int, value: SomeVarint): ProtoField =
proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index)
when type(value) is uint64:
@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField =
else:
result.vint = cast[uint64](value)
proc initProtoField*(index: int, value: bool): ProtoField =
proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index)
result.vint = byte(value)
proc initProtoField*(index: int, value: openarray[byte]): ProtoField =
proc initProtoField*(index: int,
value: openarray[byte]): ProtoField {.deprecated.} =
## Initialize ProtoField with bytes array.
result = ProtoField(kind: Length, index: index)
if len(value) > 0:
result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: string): ProtoField =
proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} =
## Initialize ProtoField with string.
result = ProtoField(kind: Length, index: index)
if len(value) > 0:
result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} =
proc initProtoField*(index: int,
value: ProtoBuffer): ProtoField {.deprecated, inline.} =
## Initialize ProtoField with nested message stored in ``value``.
##
## Note: This procedure performs shallow copy of ``value`` sequence.
@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0,
result.offset = offset
result.options = options
proc initProtoBuffer*(data: openarray[byte], offset = 0,
options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeqOfCap[byte](128)
@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
result.offset = 10
elif {WithUint32LeLength, WithUint32BeLength} * options != {}:
# Our buffer will start from position 4, so we can store length of buffer
# in [0, 9].
# in [0, 3].
result.buffer.setLen(4)
result.offset = 4
proc write*(pb: var ProtoBuffer, field: ProtoField) =
proc write*[T: ProtoScalar](pb: var ProtoBuffer,
field: int, value: T) =
checkFieldNumber(field)
var length = 0
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
vsizeof(value)
let header = ProtoFieldKind.Varint
elif T is float32:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
sizeof(T)
let header = ProtoFieldKind.Fixed32
elif T is float64:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
sizeof(T)
let header = ProtoFieldKind.Fixed64
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, header))
doAssert(hres.isOk())
pb.offset += length
when (T is uint64) or (T is uint32) or (T is uint):
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let vres = putSVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
let u32 = cast[uint32](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
let u64 = cast[uint64](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
value: openarray[T]) =
checkFieldNumber(field)
var length = 0
let dlength =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
var res = 0
for item in value:
res += vsizeof(item)
res
elif (T is float32) or (T is float64):
len(value) * sizeof(T)
let header = getProtoHeader(field, ProtoFieldKind.Length)
let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length, header)
doAssert(hres.isOk())
pb.offset += length
length = 0
let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength))
doAssert(lres.isOk())
pb.offset += length
for item in value:
when (T is uint64) or (T is uint32) or (T is uint):
length = 0
let vres = PB.putUVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
length = 0
let vres = PB.putSVarint(pb.toOpenArray(), length, item)
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
let u32 = cast[uint32](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
let u64 = cast[uint64](item)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
value: openarray[T]) =
checkFieldNumber(field)
var length = 0
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
vsizeof(uint64(len(value))) + len(value)
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, ProtoFieldKind.Length))
doAssert(hres.isOk())
pb.offset += length
let lres = PB.putUVarint(pb.toOpenArray(), length,
uint64(len(value)))
doAssert(lres.isOk())
pb.offset += length
if len(value) > 0:
doAssert(pb.isEnough(len(value)))
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
pb.offset += len(value)
proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} =
## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer
## ``pb`` with field number ``field``.
write(pb, field, value.buffer)
proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} =
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
var length = 0
var res: VarintResult[void]
pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field))
res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field))
doAssert(res.isOk())
pb.offset += length
case field.kind
@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) =
pb.offset = pos
elif WithUint32BeLength in pb.options:
let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte((size shr 24) and 0xFF'u)
pb.buffer[1] = byte((size shr 16) and 0xFF'u)
pb.buffer[2] = byte((size shr 8) and 0xFF'u)
pb.buffer[3] = byte(size and 0xFF'u)
pb.buffer[0 ..< 4] = toBytesBE(uint32(size))
pb.offset = 4
elif WithUint32LeLength in pb.options:
let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte(size and 0xFF'u)
pb.buffer[1] = byte((size shr 8) and 0xFF'u)
pb.buffer[2] = byte((size shr 16) and 0xFF'u)
pb.buffer[3] = byte((size shr 24) and 0xFF'u)
pb.buffer[0 ..< 4] = toBytesLE(uint32(size))
pb.offset = 4
else:
pb.offset = 0
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
var length = 0
var hdr = 0'u64
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
let index = uint64(hdr shr 3)
let wire = hdr and 0x07
if wire in SupportedWireTypes:
data.offset += length
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
true
else:
false
else:
false
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
case header.wire
of ProtoFieldKind.Varint:
var length = 0
var value = 0'u64
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
true
else:
false
of ProtoFieldKind.Fixed32:
if data.isEnough(sizeof(uint32)):
data.offset += sizeof(uint32)
true
else:
false
of ProtoFieldKind.Fixed64:
if data.isEnough(sizeof(uint64)):
data.offset += sizeof(uint64)
true
else:
false
of ProtoFieldKind.Length:
var length = 0
var bsize = 0'u64
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
data.offset += int(bsize)
true
else:
false
else:
false
else:
false
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
false
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
header: ProtoHeader,
outval: var T): ProtoResult =
when (T is uint64) or (T is uint32) or (T is uint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
var value = T(0)
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
outval = value
ProtoResult.NoError
else:
ProtoResult.VarintDecodeError
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0
var value = T(0)
if getSVarint(data.toOpenArray(), length, value).isOk():
data.offset += length
outval = value
ProtoResult.NoError
else:
ProtoResult.VarintDecodeError
elif T is float32:
doAssert(header.wire == ProtoFieldKind.Fixed32)
if data.isEnough(sizeof(float32)):
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
data.offset += sizeof(float32)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
elif T is float64:
doAssert(header.wire == ProtoFieldKind.Fixed64)
if data.isEnough(sizeof(float64)):
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
data.offset += sizeof(float64)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var openarray[T],
outLength: var int): ProtoResult =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
outLength = 0
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
outLength = int(bsize)
if len(outBytes) >= int(bsize):
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ProtoResult.NoError
else:
# Buffer overflow should not be critical failure
data.offset += int(bsize)
ProtoResult.BufferOverflowError
else:
ProtoResult.MessageIncompleteError
else:
ProtoResult.MessageSizeTooBigError
else:
ProtoResult.VarintDecodeError
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var T): ProtoResult =
doAssert(header.wire == ProtoFieldKind.Length)
var length = 0
var bsize = 0'u64
outBytes.setLen(0)
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
data.offset += length
if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)):
outBytes.setLen(bsize)
if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize)
ProtoResult.NoError
else:
ProtoResult.MessageIncompleteError
else:
ProtoResult.MessageSizeTooBigError
else:
ProtoResult.VarintDecodeError
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var T): bool =
checkFieldNumber(field)
var value: T
var res = false
var pb = data
output = T(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output = T(0)
return false
let wireCheck =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
header.wire == ProtoFieldKind.Varint
elif T is float32:
header.wire == ProtoFieldKind.Fixed32
elif T is float64:
header.wire == ProtoFieldKind.Fixed64
if header.index == uint64(field):
if wireCheck:
let r = getValue(pb, header, value)
case r
of ProtoResult.NoError:
res = true
output = value
else:
return false
else:
# We are ignoring wire types different from what we expect, because it
# is how `protoc` is working.
if not(skipValue(pb, header)):
output = T(0)
return false
else:
if not(skipValue(pb, header)):
output = T(0)
return false
res
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
output: var openarray[T],
outlen: var int): bool =
checkFieldNumber(field)
var pb = data
var res = false
outlen = 0
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output, outlen)
case r
of ProtoResult.NoError:
res = true
of ProtoResult.BufferOverflowError:
# Buffer overflow error is not critical error, we still can get
# field values with proper size.
discard
else:
if len(output) > 0:
zeroMem(addr output[0], len(output))
return false
else:
# We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working.
if not(skipValue(pb, header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
else:
if not(skipValue(pb, header)):
if len(output) > 0:
zeroMem(addr output[0], len(output))
outlen = 0
return false
res
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var T): bool =
checkFieldNumber(field)
var res = false
var pb = data
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output)
case r
of ProtoResult.NoError:
res = true
of ProtoResult.BufferOverflowError:
# Buffer overflow error is not critical error, we still can get
# field values with proper size.
discard
else:
output.setLen(0)
return false
else:
# We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working.
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
res
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
inline.} =
var buffer: seq[byte]
if pb.getField(field, buffer):
output = initProtoBuffer(buffer)
true
else:
false
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
var item: T
let r = getValue(pb, header, item)
case r
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
ProtoFieldKind.Fixed64}:
var item: T
let r = getValue(pb, header, item)
case r
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): bool =
checkFieldNumber(field)
var pb = data
output.setLen(0)
while not(pb.isEmpty()):
var header: ProtoHeader
if not(pb.getHeader(header)):
output.setLen(0)
return false
if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length:
var arritem: seq[byte]
let rarr = getValue(pb, header, arritem)
case rarr
of ProtoResult.NoError:
var pbarr = initProtoBuffer(arritem)
let itemHeader =
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
ProtoHeader(wire: ProtoFieldKind.Varint)
elif T is float32:
ProtoHeader(wire: ProtoFieldKind.Fixed32)
elif T is float64:
ProtoHeader(wire: ProtoFieldKind.Fixed64)
while not(pbarr.isEmpty()):
var item: T
let res = getValue(pbarr, itemHeader, item)
case res
of ProtoResult.NoError:
output.add(item)
else:
output.setLen(0)
return false
else:
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
else:
if not(skipValue(pb, header)):
output.setLen(0)
return false
if len(output) > 0:
true
else:
false
proc getVarintValue*(data: var ProtoBuffer, field: int,
value: var SomeVarint): int =
value: var SomeVarint): int {.deprecated.} =
## Get value of `Varint` type.
var length = 0
var header = 0'u64
var soffset = data.offset
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length
if header == protoHeader(field, Varint):
if header == getProtoHeader(field, Varint):
if not data.isEmpty():
when type(value) is int32 or type(value) is int64 or type(value) is int:
let res = getSVarint(data.toOpenArray(), length, value)
@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int,
data.offset = soffset
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
buffer: var T): int =
buffer: var T): int {.deprecated.} =
## Get value of `Length` type.
var length = 0
var header = 0'u64
@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
var soffset = data.offset
result = -1
buffer.setLen(0)
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length
if header == protoHeader(field, Length):
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk():
if header == getProtoHeader(field, Length):
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, ssize).isOk():
data.offset += length
if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
buffer.setLen(ssize)
@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
data.offset = soffset
proc getBytes*(data: var ProtoBuffer, field: int,
buffer: var seq[byte]): int {.inline.} =
buffer: var seq[byte]): int {.deprecated, inline.} =
## Get value of `Length` type as bytes.
result = getLengthValue(data, field, buffer)
proc getString*(data: var ProtoBuffer, field: int,
buffer: var string): int {.inline.} =
buffer: var string): int {.deprecated, inline.} =
## Get value of `Length` type as string.
result = getLengthValue(data, field, buffer)
proc enterSubmessage*(pb: var ProtoBuffer): int =
proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} =
## Processes protobuf's sub-message and adjust internal offset to enter
## inside of sub-message. Returns field index of sub-message field or
## ``0`` on error.
@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
var msize = 0'u64
var soffset = pb.offset
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk():
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, header).isOk():
pb.offset += length
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk():
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, msize).isOk():
pb.offset += length
if msize <= MaxMessageSize and pb.isEnough(int(msize)):
pb.length = int(msize)
@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
# Restore offset on error
pb.offset = soffset
proc skipSubmessage*(pb: var ProtoBuffer) =
proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} =
## Skip current protobuf's sub-message and adjust internal offset to the
## end of sub-message.
doAssert(pb.length != 0)

View File

@ -47,61 +47,49 @@ type
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
result = initProtoBuffer()
result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet()))
result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
for ma in peerInfo.addrs:
result.write(initProtoField(2, ma.data.buffer))
result.write(2, ma.data.buffer)
for proto in peerInfo.protocols:
result.write(initProtoField(3, proto))
result.write(3, proto)
result.write(initProtoField(4, observedAddr.data.buffer))
result.write(4, observedAddr.data.buffer)
let protoVersion = ProtoVersion
result.write(initProtoField(5, protoVersion))
result.write(5, protoVersion)
let agentVersion = AgentVersion
result.write(initProtoField(6, agentVersion))
result.write(6, agentVersion)
result.finish()
proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
var pb = initProtoBuffer(buf)
result.pubKey = none(PublicKey)
var pubKey: PublicKey
if pb.getValue(1, pubKey) > 0:
if pb.getField(1, pubKey):
trace "read public key from message", pubKey = ($pubKey).shortLog
result.pubKey = some(pubKey)
result.addrs = newSeq[MultiAddress]()
var address = newSeq[byte]()
while pb.getBytes(2, address) > 0:
if len(address) != 0:
var copyaddr = address
var ma = MultiAddress.init(copyaddr).tryGet()
result.addrs.add(ma)
trace "read address bytes from message", address = ma
address.setLen(0)
if pb.getRepeatedField(2, result.addrs):
trace "read addresses from message", addresses = result.addrs
var proto = ""
while pb.getString(3, proto) > 0:
trace "read proto from message", proto = proto
result.protos.add(proto)
proto = ""
if pb.getRepeatedField(3, result.protos):
trace "read protos from message", protocols = result.protos
var observableAddr = newSeq[byte]()
if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr
var ma = MultiAddress.init(observableAddr).tryGet()
trace "read observedAddr from message", address = ma
result.observedAddr = some(ma)
var observableAddr: MultiAddress
if pb.getField(4, observableAddr):
trace "read observableAddr from message", address = observableAddr
result.observedAddr = some(observableAddr)
var protoVersion = ""
if pb.getString(5, protoVersion) > 0:
if pb.getField(5, protoVersion):
trace "read protoVersion from message", protoVersion = protoVersion
result.protoVersion = some(protoVersion)
var agentVersion = ""
if pb.getString(6, agentVersion) > 0:
if pb.getField(6, agentVersion):
trace "read agentVersion from message", agentVersion = agentVersion
result.agentVersion = some(agentVersion)

View File

@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string =
byteutils.toHex(m.seqno) & m.fromPeer.pretty
proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} =
var buff = initProtoBuffer()
encodeMessage(msg, buff)
p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes()
p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes()
proc verify*(m: Message, p: PeerInfo): bool =
if m.signature.len > 0 and m.key.len > 0:
@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool =
msg.signature = @[]
msg.key = @[]
var buff = initProtoBuffer()
encodeMessage(msg, buff)
var remote: Signature
var key: PublicKey
if remote.init(m.signature) and key.init(m.key):
trace "verifying signature", remoteSignature = remote
result = remote.verify(PubSubPrefix & buff.buffer, key)
result = remote.verify(PubSubPrefix & encodeMessage(msg), key)
if result:
libp2p_pubsub_sig_verify_success.inc()

View File

@ -14,265 +14,247 @@ import messages,
../../../utility,
../../../protobuf/minprotobuf
proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, graft.topicID))
proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
var ipb = initProtoBuffer()
ipb.write(1, graft.topicID)
ipb.finish()
pb.write(field, ipb)
proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} =
trace "decoding graft msg", buffer = pb.buffer.shortLog
while true:
var topic: string
if pb.getString(1, topic) < 0:
break
proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
var ipb = initProtoBuffer()
ipb.write(1, prune.topicID)
ipb.finish()
pb.write(field, ipb)
trace "read topic field from graft msg", topicID = topic
result.add(ControlGraft(topicID: topic))
proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, prune.topicID))
proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
trace "decoding prune msg"
while true:
var topic: string
if pb.getString(1, topic) < 0:
break
trace "read topic field from prune msg", topicID = topic
result.add(ControlPrune(topicID: topic))
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, ihave.topicID))
proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
var ipb = initProtoBuffer()
ipb.write(1, ihave.topicID)
for mid in ihave.messageIDs:
pb.write(initProtoField(2, mid))
ipb.write(2, mid)
ipb.finish()
pb.write(field, ipb)
proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} =
trace "decoding ihave msg"
while true:
var control: ControlIHave
if pb.getString(1, control.topicID) < 0:
trace "topic field missing from ihave msg"
break
trace "read topic field", topicID = control.topicID
while true:
var mid: string
if pb.getString(2, mid) < 0:
break
trace "read messageID field", mid = mid
control.messageIDs.add(mid)
result.add(control)
proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} =
proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
var ipb = initProtoBuffer()
for mid in iwant.messageIDs:
pb.write(initProtoField(1, mid))
ipb.write(1, mid)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} =
trace "decoding iwant msg"
proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
var ipb = initProtoBuffer()
for ihave in control.ihave:
ipb.write(1, ihave)
for iwant in control.iwant:
ipb.write(2, iwant)
for graft in control.graft:
ipb.write(3, graft)
for prune in control.prune:
ipb.write(4, prune)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
var control: ControlIWant
while true:
var mid: string
if pb.getString(1, mid) < 0:
break
control.messageIDs.add(mid)
trace "read messageID field", mid = mid
result.add(control)
proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} =
if control.ihave.len > 0:
var ihave = initProtoBuffer()
for h in control.ihave:
h.encodeIHave(ihave)
# write messages to protobuf
if ihave.buffer.len > 0:
ihave.finish()
pb.write(initProtoField(1, ihave))
if control.iwant.len > 0:
var iwant = initProtoBuffer()
for w in control.iwant:
w.encodeIWant(iwant)
# write messages to protobuf
if iwant.buffer.len > 0:
iwant.finish()
pb.write(initProtoField(2, iwant))
if control.graft.len > 0:
var graft = initProtoBuffer()
for g in control.graft:
g.encodeGraft(graft)
# write messages to protobuf
if graft.buffer.len > 0:
graft.finish()
pb.write(initProtoField(3, graft))
if control.prune.len > 0:
var prune = initProtoBuffer()
for p in control.prune:
p.encodePrune(prune)
# write messages to protobuf
if prune.buffer.len > 0:
prune.finish()
pb.write(initProtoField(4, prune))
proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} =
trace "decoding control submessage"
var control: ControlMessage
while true:
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in Control msg"
break
of 1:
control.ihave &= pb.decodeIHave()
of 2:
control.iwant &= pb.decodeIWant()
of 3:
control.graft &= pb.decodeGraft()
of 4:
control.prune &= pb.decodePrune()
else:
raise newException(CatchableError, "message type not recognized")
if result.isNone:
result = some(control)
proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, subs.subscribe))
pb.write(initProtoField(2, subs.topic))
proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} =
while true:
var subOpt: SubOpts
var subscr: uint
discard pb.getVarintValue(1, subscr)
subOpt.subscribe = cast[bool](subscr)
trace "read subscribe field", subscribe = subOpt.subscribe
if pb.getString(2, subOpt.topic) < 0:
break
trace "read subscribe field", topicName = subOpt.topic
result.add(subOpt)
trace "got subscriptions", subscriptions = result
proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, msg.fromPeer.getBytes()))
pb.write(initProtoField(2, msg.data))
pb.write(initProtoField(3, msg.seqno))
for t in msg.topicIDs:
pb.write(initProtoField(4, t))
if msg.signature.len > 0:
pb.write(initProtoField(5, msg.signature))
if msg.key.len > 0:
pb.write(initProtoField(6, msg.key))
proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) =
var ipb = initProtoBuffer()
ipb.write(1, uint64(subs.subscribe))
ipb.write(2, subs.topic)
ipb.finish()
pb.write(field, ipb)
proc encodeMessage*(msg: Message): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, msg.fromPeer)
pb.write(2, msg.data)
pb.write(3, msg.seqno)
for topic in msg.topicIDs:
pb.write(4, topic)
if len(msg.signature) > 0:
pb.write(5, msg.signature)
if len(msg.key) > 0:
pb.write(6, msg.key)
pb.finish()
pb.buffer
proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} =
# TODO: which of this fields are really optional?
while true:
var msg: Message
var fromPeer: seq[byte]
if pb.getBytes(1, fromPeer) < 0:
break
try:
msg.fromPeer = PeerID.init(fromPeer).tryGet()
except CatchableError as err:
debug "Invalid fromPeer in message", msg = err.msg
break
proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
pb.write(field, encodeMessage(msg))
trace "read message field", fromPeer = msg.fromPeer.pretty
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
trace "decodeGraft: decoding message"
var control = ControlGraft()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeGraft: read topicId", topic_id = topicId
else:
trace "decodeGraft: topicId is missing"
control
if pb.getBytes(2, msg.data) < 0:
break
trace "read message field", data = msg.data.shortLog
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
trace "decodePrune: decoding message"
var control = ControlPrune()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodePrune: read topicId", topic_id = topicId
else:
trace "decodePrune: topicId is missing"
control
if pb.getBytes(3, msg.seqno) < 0:
break
trace "read message field", seqno = msg.seqno.shortLog
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
trace "decodeIHave: decoding message"
var control = ControlIHave()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeIHave: read topicId", topic_id = topicId
else:
trace "decodeIHave: topicId is missing"
if pb.getRepeatedField(2, control.messageIDs):
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIHave: no messageIDs"
control
var topic: string
while true:
if pb.getString(4, topic) < 0:
break
msg.topicIDs.add(topic)
trace "read message field", topicName = topic
topic = ""
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
trace "decodeIWant: decoding message"
var control = ControlIWant()
if pb.getRepeatedField(1, control.messageIDs):
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIWant: no messageIDs"
discard pb.getBytes(5, msg.signature)
trace "read message field", signature = msg.signature.shortLog
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
trace "decodeControl: decoding message"
var buffer: seq[byte]
if pb.getField(3, buffer):
var control: ControlMessage
var cpb = initProtoBuffer(buffer)
var ihavepbs: seq[seq[byte]]
var iwantpbs: seq[seq[byte]]
var graftpbs: seq[seq[byte]]
var prunepbs: seq[seq[byte]]
discard pb.getBytes(6, msg.key)
trace "read message field", key = msg.key.shortLog
discard cpb.getRepeatedField(1, ihavepbs)
discard cpb.getRepeatedField(2, iwantpbs)
discard cpb.getRepeatedField(3, graftpbs)
discard cpb.getRepeatedField(4, prunepbs)
result.add(msg)
for item in ihavepbs:
control.ihave.add(decodeIHave(initProtoBuffer(item)))
for item in iwantpbs:
control.iwant.add(decodeIWant(initProtoBuffer(item)))
for item in graftpbs:
control.graft.add(decodeGraft(initProtoBuffer(item)))
for item in prunepbs:
control.prune.add(decodePrune(initProtoBuffer(item)))
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} =
result = initProtoBuffer()
trace "encoding msg: ", msg = msg.shortLog
trace "decodeControl: "
some(control)
else:
none[ControlMessage]()
if msg.subscriptions.len > 0:
for s in msg.subscriptions:
var subs = initProtoBuffer()
encodeSubs(s, subs)
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
trace "decodeSubscription: decoding message"
var subflag: uint64
var sub = SubOpts()
if pb.getField(1, subflag):
sub.subscribe = bool(subflag)
trace "decodeSubscription: read subscribe", subscribe = subflag
else:
trace "decodeSubscription: subscribe is missing"
if pb.getField(2, sub.topic):
trace "decodeSubscription: read topic", topic = sub.topic
else:
trace "decodeSubscription: topic is missing"
# write subscriptions to protobuf
if subs.buffer.len > 0:
subs.finish()
result.write(initProtoField(1, subs))
sub
if msg.messages.len > 0:
var messages = initProtoBuffer()
for m in msg.messages:
encodeMessage(m, messages)
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
trace "decodeSubscriptions: decoding message"
var subpbs: seq[seq[byte]]
var subs: seq[SubOpts]
if pb.getRepeatedField(1, subpbs):
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
for item in subpbs:
let sub = decodeSubscription(initProtoBuffer(item))
subs.add(sub)
# write messages to protobuf
if messages.buffer.len > 0:
messages.finish()
result.write(initProtoField(2, messages))
if len(subs) == 0:
trace "decodeSubscription: no subscriptions found"
if msg.control.isSome:
var control = initProtoBuffer()
msg.control.get.encodeControl(control)
subs
# write messages to protobuf
if control.buffer.len > 0:
control.finish()
result.write(initProtoField(3, control))
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
trace "decodeMessage: decoding message"
var msg: Message
if pb.getField(1, msg.fromPeer):
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
else:
trace "decodeMessage: fromPeer is missing"
if result.buffer.len > 0:
result.finish()
if pb.getField(2, msg.data):
trace "decodeMessage: read data", data = msg.data.shortLog()
else:
trace "decodeMessage: data is missing"
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} =
if pb.getField(3, msg.seqno):
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
else:
trace "decodeMessage: seqno is missing"
if pb.getRepeatedField(4, msg.topicIDs):
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
else:
trace "decodeMessage: topics are missing"
if pb.getField(5, msg.signature):
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
else:
trace "decodeMessage: signature is missing"
if pb.getField(6, msg.key):
trace "decodeMessage: read public key", key = msg.key.shortLog()
else:
trace "decodeMessage: public key is missing"
msg
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
trace "decodeMessages: decoding message"
var msgpbs: seq[seq[byte]]
var msgs: seq[Message]
if pb.getRepeatedField(2, msgpbs):
trace "decodeMessages: read messages", count = len(msgpbs)
for item in msgpbs:
let msg = decodeMessage(initProtoBuffer(item))
msgs.add(msg)
if len(msgs) == 0:
trace "decodeMessages: no messages found"
msgs
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
var pb = initProtoBuffer()
for item in msg.subscriptions:
pb.write(1, item)
for item in msg.messages:
pb.write(2, item)
if msg.control.isSome():
pb.write(3, msg.control.get())
if len(pb.buffer) > 0:
pb.finish()
result = pb
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
var pb = initProtoBuffer(msg)
var rpcMsg: RPCMsg
rpcMsg.messages = pb.decodeMessages()
rpcMsg.subscriptions = pb.decodeSubscriptions()
rpcMsg.control = pb.decodeControl()
while true:
# decode SubOpts array
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in RPC msg"
break
of 1:
result.subscriptions &= pb.decodeSubs()
of 2:
result.messages &= pb.decodeMessages()
of 3:
result.control = pb.decodeControl()
else:
raise newException(CatchableError, "message type not recognized")
rpcMsg

View File

@ -430,8 +430,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
var
libp2pProof = initProtoBuffer()
libp2pProof.write(initProtoField(1, p.localPublicKey))
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
libp2pProof.write(1, p.localPublicKey)
libp2pProof.write(2, signedPayload.getBytes())
# data field also there but not used!
libp2pProof.finish()
@ -449,9 +449,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
remoteSig: Signature
remoteSigBytes: seq[byte]
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
if not(remoteProof.getField(1, remotePubKeyBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
if not(remoteProof.getField(2, remoteSigBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes):

677
tests/testminprotobuf.nim Normal file
View File

@ -0,0 +1,677 @@
## Nim-Libp2p
## Copyright (c) 2018 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import unittest
import ../libp2p/protobuf/minprotobuf
import stew/byteutils, strutils
when defined(nimHasUsed): {.used.}
suite "MinProtobuf test suite":
const VarintVectors = [
"0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f",
"08ffffffffffffffffff01"
]
const VarintValues = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64
]
const Fixed32Vectors = [
"0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff"
]
const Fixed32Values = [
0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32
]
const Fixed64Vectors = [
"090000000000000000", "090100000000000000", "09ffffff7f00000000",
"09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f",
"099988ffeeddccbbaa", "09ffffffffffffffff"
]
const Fixed64Values = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64,
0xFFFF_FFFF_FFFF_FFFF'u64
]
const LengthVectors = [
"0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768"
]
const LengthValues = [
"", "a", "ab", "abcd", "abcdefgh"
]
## This vector values was tested with `protoc` and related proto file.
## syntax = "proto2";
## message testmsg {
## repeated uint64 d = 1 [packed=true];
## repeated uint64 d = 2 [packed=true];
## }
const PackedVarintVector =
"0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001"
## syntax = "proto2";
## message testmsg {
## repeated sfixed32 d = 1 [packed=true];
## repeated sfixed32 d = 2 [packed=true];
## }
const PackedFixed32Vector =
"0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000"
## syntax = "proto2";
## message testmsg {
## repeated sfixed64 d = 1 [packed=true];
## repeated sfixed64 d = 2 [packed=true];
## }
const PackedFixed64Vector =
"""0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000
ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210
00000000000000000100000000000000"""
proc getVarintEncodedValue(value: uint64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getVarintDecodedValue(data: openarray[byte]): uint64 =
var value: uint64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
value
proc getFixed32EncodedValue(value: float32): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed32DecodedValue(data: openarray[byte]): uint32 =
var value: float32
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint32](value)
proc getFixed64EncodedValue(value: float64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed64DecodedValue(data: openarray[byte]): uint64 =
var value: float64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint64](value)
proc getLengthEncodedValue(value: string): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthEncodedValue(value: seq[byte]): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthDecodedValue(data: openarray[byte]): string =
var value = newString(len(data))
var valueLen = 0
var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
doAssert(res)
value.setLen(valueLen)
value
proc isFullZero[T: byte|char](data: openarray[T]): bool =
for ch in data:
if int(ch) != 0:
return false
return true
proc corruptHeader(data: var openarray[byte], index: int) =
var values = [3, 4, 6]
data[0] = data[0] and 0xF8'u8
data[0] = data[0] or byte(values[index mod len(values)])
test "[varint] edge values test":
for i in 0 ..< len(VarintValues):
let data = getVarintEncodedValue(VarintValues[i])
check:
toHex(data) == VarintVectors[i]
getVarintDecodedValue(data) == VarintValues[i]
test "[varint] mixing many values with same field number test":
for i in 0 ..< len(VarintValues):
var pb = initProtoBuffer()
for k in 0 ..< len(VarintValues):
let index = (i + k + 1) mod len(VarintValues)
pb.write(1, VarintValues[index])
pb.finish()
check getVarintDecodedValue(pb.buffer) == VarintValues[i]
test "[varint] incorrect values test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] non-existent field test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == 0'u64
test "[varint] corrupted header test":
for i in 0 ..< len(VarintValues):
for k in 0 ..< 3:
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] empty buffer test":
var value: uint64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == 0'u64
test "[varint] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, VarintValues[1])
pb1.write(1, VarintValues[2])
pb1.write(2, VarintValues[3])
pb1.write(1, VarintValues[4])
pb1.write(1, VarintValues[5])
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
fieldarr1[0] == VarintValues[1]
fieldarr1[1] == VarintValues[2]
fieldarr1[2] == VarintValues[4]
fieldarr1[3] == VarintValues[5]
fieldarr2[0] == VarintValues[3]
test "[varint] Repeated packed field test":
var pb1 = initProtoBuffer()
pb1.writePacked(1, VarintValues)
pb1.writePacked(2, VarintValues[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedVarintVector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 6
fieldarr1[0] == VarintValues[0]
fieldarr1[1] == VarintValues[1]
fieldarr1[2] == VarintValues[2]
fieldarr1[3] == VarintValues[3]
fieldarr1[4] == VarintValues[4]
fieldarr1[5] == VarintValues[5]
fieldarr2[0] == VarintValues[0]
fieldarr2[1] == VarintValues[1]
test "[fixed32] edge values test":
for i in 0 ..< len(Fixed32Values):
let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i]))
check:
toHex(data) == Fixed32Vectors[i]
getFixed32DecodedValue(data) == Fixed32Values[i]
test "[fixed32] mixing many values with same field number test":
for i in 0 ..< len(Fixed32Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed32Values):
let index = (i + k + 1) mod len(Fixed32Values)
pb.write(1, cast[float32](Fixed32Values[index]))
pb.finish()
check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i]
test "[fixed32] incorrect values test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] non-existent field test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float32(0)
test "[fixed32] corrupted header test":
for i in 0 ..< len(Fixed32Values):
for k in 0 ..< 3:
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] empty buffer test":
var value: float32
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float32(0)
test "[fixed32] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float32](Fixed32Values[0]))
pb1.write(1, cast[float32](Fixed32Values[1]))
pb1.write(2, cast[float32](Fixed32Values[2]))
pb1.write(1, cast[float32](Fixed32Values[3]))
pb1.write(1, cast[float32](Fixed32Values[4]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint32](fieldarr1[0]) == Fixed64Values[0]
cast[uint32](fieldarr1[1]) == Fixed64Values[1]
cast[uint32](fieldarr1[2]) == Fixed64Values[3]
cast[uint32](fieldarr1[3]) == Fixed64Values[4]
cast[uint32](fieldarr2[0]) == Fixed64Values[2]
test "[fixed32] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float32](len(Fixed32Values))
for i in 0 ..< len(values):
values[i] = cast[float32](Fixed32Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedFixed32Vector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 5
cast[uint32](fieldarr1[0]) == Fixed32Values[0]
cast[uint32](fieldarr1[1]) == Fixed32Values[1]
cast[uint32](fieldarr1[2]) == Fixed32Values[2]
cast[uint32](fieldarr1[3]) == Fixed32Values[3]
cast[uint32](fieldarr1[4]) == Fixed32Values[4]
cast[uint32](fieldarr2[0]) == Fixed32Values[0]
cast[uint32](fieldarr2[1]) == Fixed32Values[1]
test "[fixed64] edge values test":
for i in 0 ..< len(Fixed64Values):
let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
check:
toHex(data) == Fixed64Vectors[i]
getFixed64DecodedValue(data) == Fixed64Values[i]
test "[fixed64] mixing many values with same field number test":
for i in 0 ..< len(Fixed64Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed64Values):
let index = (i + k + 1) mod len(Fixed64Values)
pb.write(1, cast[float64](Fixed64Values[index]))
pb.finish()
check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i]
test "[fixed64] incorrect values test":
for i in 0 ..< len(Fixed64Values):
var value: float32
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] non-existent field test":
for i in 0 ..< len(Fixed64Values):
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float64(0)
test "[fixed64] corrupted header test":
for i in 0 ..< len(Fixed64Values):
for k in 0 ..< 3:
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] empty buffer test":
var value: float64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float64(0)
test "[fixed64] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float64](Fixed64Values[2]))
pb1.write(1, cast[float64](Fixed64Values[3]))
pb1.write(2, cast[float64](Fixed64Values[4]))
pb1.write(1, cast[float64](Fixed64Values[5]))
pb1.write(1, cast[float64](Fixed64Values[6]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint64](fieldarr1[0]) == Fixed64Values[2]
cast[uint64](fieldarr1[1]) == Fixed64Values[3]
cast[uint64](fieldarr1[2]) == Fixed64Values[5]
cast[uint64](fieldarr1[3]) == Fixed64Values[6]
cast[uint64](fieldarr2[0]) == Fixed64Values[4]
test "[fixed64] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float64](len(Fixed64Values))
for i in 0 ..< len(values):
values[i] = cast[float64](Fixed64Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", ""))
check:
toHex(pb1.buffer) == expect
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 8
cast[uint64](fieldarr1[0]) == Fixed64Values[0]
cast[uint64](fieldarr1[1]) == Fixed64Values[1]
cast[uint64](fieldarr1[2]) == Fixed64Values[2]
cast[uint64](fieldarr1[3]) == Fixed64Values[3]
cast[uint64](fieldarr1[4]) == Fixed64Values[4]
cast[uint64](fieldarr1[5]) == Fixed64Values[5]
cast[uint64](fieldarr1[6]) == Fixed64Values[6]
cast[uint64](fieldarr1[7]) == Fixed64Values[7]
cast[uint64](fieldarr2[0]) == Fixed64Values[0]
cast[uint64](fieldarr2[1]) == Fixed64Values[1]
test "[length] edge values test":
for i in 0 ..< len(LengthValues):
let data1 = getLengthEncodedValue(LengthValues[i])
let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i]))
check:
toHex(data1) == LengthVectors[i]
toHex(data2) == LengthVectors[i]
check:
getLengthDecodedValue(data1) == LengthValues[i]
getLengthDecodedValue(data2) == LengthValues[i]
test "[length] mixing many values with same field number test":
for i in 0 ..< len(LengthValues):
var pb1 = initProtoBuffer()
var pb2 = initProtoBuffer()
for k in 0 ..< len(LengthValues):
let index = (i + k + 1) mod len(LengthValues)
pb1.write(1, LengthValues[index])
pb2.write(1, cast[seq[byte]](LengthValues[index]))
pb1.finish()
pb2.finish()
check getLengthDecodedValue(pb1.buffer) == LengthValues[i]
check getLengthDecodedValue(pb2.buffer) == LengthValues[i]
test "[length] incorrect values test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] non-existent field test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value, valueLen) == false
valueLen == 0
test "[length] corrupted header test":
for i in 0 ..< len(LengthValues):
for k in 0 ..< 3:
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] empty buffer test":
var value = newSeq[byte](len(LengthValues[0]))
var valueLen = 0
var pb = initProtoBuffer()
check:
pb.getField(1, value, valueLen) == false
valueLen == 0
test "[length] buffer overflow test":
for i in 1 ..< len(LengthValues):
let data = getLengthEncodedValue(LengthValues[i])
var value = newString(len(LengthValues[i]) - 1)
var valueLen = 0
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
valueLen == len(LengthValues[i])
isFullZero(value) == true
test "[length] mix of buffer overflow and normal fields test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST10")
pb1.write(1, "TEST20")
pb1.write(1, "TEST")
pb1.write(1, "TEST30")
pb1.write(1, "SOME")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(4)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == true
value == "SOME"
test "[length] too big message test":
var pb1 = initProtoBuffer()
var bigString = newString(MaxMessageSize + 1)
for i in 0 ..< len(bigString):
bigString[i] = 'A'
pb1.write(1, bigString)
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(MaxMessageSize + 1)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == false
test "[length] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST1")
pb1.write(1, "TEST2")
pb1.write(2, "TEST5")
pb1.write(1, "TEST3")
pb1.write(1, "TEST4")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[seq[byte]]
var fieldarr2: seq[seq[byte]]
var fieldarr3: seq[seq[byte]]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[string](fieldarr1[0]) == "TEST1"
cast[string](fieldarr1[1]) == "TEST2"
cast[string](fieldarr1[2]) == "TEST3"
cast[string](fieldarr1[3]) == "TEST4"
cast[string](fieldarr2[0]) == "TEST5"
test "Different value types in one message with same field number test":
proc getEncodedValue(): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.write(1, VarintValues[2])
pb.write(2, cast[float32](Fixed32Values[2]))
pb.write(3, cast[float64](Fixed64Values[2]))
pb.write(4, LengthValues[2])
pb.write(1, cast[float32](Fixed32Values[3]))
pb.write(2, cast[float64](Fixed64Values[3]))
pb.write(3, LengthValues[3])
pb.write(4, VarintValues[3])
pb.write(1, cast[float64](Fixed64Values[4]))
pb.write(2, LengthValues[4])
pb.write(3, VarintValues[4])
pb.write(4, cast[float32](Fixed32Values[4]))
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.finish()
pb.buffer
let msg = getEncodedValue()
let pb = initProtoBuffer(msg)
var varintValue: uint64
var fixed32Value: float32
var fixed64Value: float64
var lengthValue = newString(10)
var lengthSize: int
check:
pb.getField(1, varintValue) == true
pb.getField(2, fixed32Value) == true
pb.getField(3, fixed64Value) == true
pb.getField(4, lengthValue, lengthSize) == true
lengthValue.setLen(lengthSize)
check:
varintValue == VarintValues[1]
cast[uint32](fixed32Value) == Fixed32Values[1]
cast[uint64](fixed64Value) == Fixed64Values[1]
lengthValue == LengthValues[1]

View File

@ -1,4 +1,5 @@
import testvarint,
testminprotobuf,
teststreamseq
import testrsa,