[WIP] Minprotobuf refactoring (#259)
* Minprotobuf initial commit * Fix noise. * Add signed integers support. Add checks for field number value. Remove some casts. * Fix compile errors. * Fix comments and constants.
This commit is contained in:
parent
181cf73ca7
commit
efb952f18b
|
@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] =
|
|||
##
|
||||
## Returns number of bytes (octets) needed to store private key ``key``.
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
var blen = len(msg.buffer)
|
||||
if len(data) >= blen:
|
||||
|
@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] =
|
|||
##
|
||||
## Returns number of bytes (octets) needed to store public key ``key``.
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
var blen = len(msg.buffer)
|
||||
if len(data) >= blen and blen > 0:
|
||||
|
@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
|
|||
## Return private key ``key`` in binary form (using libp2p's protobuf
|
||||
## serialization).
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
ok(msg.buffer)
|
||||
|
||||
|
@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
|
|||
## Return public key ``key`` in binary form (using libp2p's protobuf
|
||||
## serialization).
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
ok(msg.buffer)
|
||||
|
||||
|
@ -283,9 +283,8 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
|
|||
var buffer: seq[byte]
|
||||
if len(data) > 0:
|
||||
var pb = initProtoBuffer(@data)
|
||||
if pb.getVarintValue(1, id) != 0:
|
||||
if pb.getBytes(2, buffer) != 0:
|
||||
if cast[int8](id) in SupportedSchemesInt:
|
||||
if pb.getField(1, id) and pb.getField(2, buffer):
|
||||
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
|
||||
var scheme = cast[PKScheme](cast[int8](id))
|
||||
when key is PrivateKey:
|
||||
var nkey = PrivateKey(scheme: scheme)
|
||||
|
@ -727,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte],
|
|||
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
|
||||
## comma-delimeted list of supported hashes ``hashes``.
|
||||
var msg = initProtoBuffer({WithUint32BeLength})
|
||||
msg.write(initProtoField(1, nonce))
|
||||
msg.write(initProtoField(2, pubkey))
|
||||
msg.write(initProtoField(3, exchanges))
|
||||
msg.write(initProtoField(4, ciphers))
|
||||
msg.write(initProtoField(5, hashes))
|
||||
msg.write(1, nonce)
|
||||
msg.write(2, pubkey)
|
||||
msg.write(3, exchanges)
|
||||
msg.write(4, ciphers)
|
||||
msg.write(5, hashes)
|
||||
msg.finish()
|
||||
shallowCopy(result, msg.buffer)
|
||||
|
||||
|
@ -744,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
if pb.getLengthValue(1, nonce) != -1 and
|
||||
pb.getLengthValue(2, pubkey) != -1 and
|
||||
pb.getLengthValue(3, exchanges) != -1 and
|
||||
pb.getLengthValue(4, ciphers) != -1 and
|
||||
pb.getLengthValue(5, hashes) != -1:
|
||||
result = true
|
||||
pb.getField(1, nonce) and pb.getField(2, pubkey) and
|
||||
pb.getField(3, exchanges) and pb.getField(4, ciphers) and
|
||||
pb.getField(5, hashes)
|
||||
|
||||
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
|
||||
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
|
||||
## signature of proposal blocks ``signature``.
|
||||
var msg = initProtoBuffer({WithUint32BeLength})
|
||||
msg.write(initProtoField(1, epubkey))
|
||||
msg.write(initProtoField(2, signature))
|
||||
msg.write(1, epubkey)
|
||||
msg.write(2, signature)
|
||||
msg.finish()
|
||||
shallowCopy(result, msg.buffer)
|
||||
|
||||
|
@ -767,9 +763,7 @@ proc decodeExchange*(message: seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
if pb.getLengthValue(1, pubkey) != -1 and
|
||||
pb.getLengthValue(2, signature) != -1:
|
||||
result = true
|
||||
pb.getField(1, pubkey) and pb.getField(2, signature)
|
||||
|
||||
## Serialization/Deserialization helpers
|
||||
|
||||
|
@ -788,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {.
|
|||
## Write Signature value ``sig`` to buffer ``vb``.
|
||||
vb.writeSeq(sig.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {.
|
||||
raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PublicKey ``pubkey``.
|
||||
result = initProtoField(index, pubkey.getBytes().tryGet())
|
||||
proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
|
||||
key: T) {.
|
||||
inline, raises: [Defect, ResultError[CryptoError]].} =
|
||||
write(pb, field, key.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {.
|
||||
raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PrivateKey ``seckey``.
|
||||
result = initProtoField(index, seckey.getBytes().tryGet())
|
||||
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
|
||||
inline, raises: [Defect, ResultError[CryptoError]].} =
|
||||
write(pb, field, sig.getBytes())
|
||||
|
||||
proc initProtoField*(index: int, sig: Signature): ProtoField =
|
||||
proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {.
|
||||
deprecated, raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PublicKey/PrivateKey ``key``.
|
||||
result = initProtoField(index, key.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with Signature ``sig``.
|
||||
result = initProtoField(index, sig.getBytes())
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
|
||||
## Read ``PublicKey`` from ProtoBuf's message and validate it.
|
||||
proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int,
|
||||
value: var T): int {.deprecated.} =
|
||||
## Read PublicKey/PrivateKey from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var key: PublicKey
|
||||
result = getLengthValue(data, field, buf)
|
||||
|
@ -813,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
|
|||
else:
|
||||
value = key
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int =
|
||||
## Read ``PrivateKey`` from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var key: PrivateKey
|
||||
result = getLengthValue(data, field, buf)
|
||||
if result > 0:
|
||||
if not key.init(buf):
|
||||
result = -1
|
||||
else:
|
||||
value = key
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
|
||||
deprecated.} =
|
||||
## Read ``Signature`` from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var sig: Signature
|
||||
|
@ -834,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
|
|||
result = -1
|
||||
else:
|
||||
value = sig
|
||||
|
||||
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
|
||||
value: var T): bool =
|
||||
var buffer: seq[byte]
|
||||
var key: T
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if key.init(buffer):
|
||||
value = key
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
|
||||
var buffer: seq[byte]
|
||||
var sig: Signature
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if sig.init(buffer):
|
||||
value = sig
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
import nativesockets
|
||||
import tables, strutils, stew/shims/net
|
||||
import chronos
|
||||
import multicodec, multihash, multibase, transcoder, vbuffer, peerid
|
||||
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
|
||||
protobuf/minprotobuf
|
||||
import stew/[base58, base32, endians2, results]
|
||||
export results
|
||||
export results, minprotobuf, vbuffer
|
||||
|
||||
type
|
||||
MAKind* = enum
|
||||
|
@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
|
|||
else:
|
||||
ok($(proto.mcodec))
|
||||
|
||||
proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] =
|
||||
proc protoArgument*(ma: MultiAddress,
|
||||
value: var openarray[byte]): MaResult[int] =
|
||||
## Returns MultiAddress ``ma`` protocol argument value.
|
||||
##
|
||||
## If current MultiAddress do not have argument value, then result will be
|
||||
|
@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
|
|||
|
||||
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
|
||||
if proto.kind == None:
|
||||
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'")
|
||||
yield err(MaResult[MultiAddress], "Unsupported protocol '" &
|
||||
$header & "'")
|
||||
|
||||
elif proto.kind == Fixed:
|
||||
data.setLen(proto.size)
|
||||
|
@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
|
|||
return ok(true)
|
||||
ok(false)
|
||||
|
||||
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
|
||||
proc `[]`*(ma: MultiAddress,
|
||||
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
|
||||
## Returns partial MultiAddress with MultiCodec ``codec`` and present in
|
||||
## MultiAddress ``ma``.
|
||||
for item in ma.items:
|
||||
|
@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] =
|
|||
return err("multiaddress: Unsupported protocol '" & $header & "'")
|
||||
if proto.kind in {Fixed, Length, Path}:
|
||||
if isNil(proto.coder.bufferToString):
|
||||
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder")
|
||||
return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
|
||||
"' coder")
|
||||
if not proto.coder.bufferToString(vb.data, part):
|
||||
return err("multiaddress: Decoding protocol error")
|
||||
parts.add($(proto.mcodec))
|
||||
|
@ -729,12 +734,14 @@ proc init*(
|
|||
of None:
|
||||
raiseAssert "None checked above"
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} =
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
|
||||
value: PeerID): MaResult[MultiAddress] {.inline.} =
|
||||
## Initialize MultiAddress object from protocol id ``protocol`` and peer id
|
||||
## ``value``.
|
||||
init(mtype, protocol, value.data)
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
|
||||
value: int): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress object from protocol id ``protocol`` and integer
|
||||
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
|
||||
## ``dccp`` and ``sctp`` MultiAddresses.
|
||||
|
@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
|
|||
if mc != InvalidMultiCodec:
|
||||
result = CodeAddresses.getOrDefault(mc)
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress],
|
||||
value: string): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress object from string representation ``value``.
|
||||
var parts = value.trimRight('/').split('/')
|
||||
if len(parts[0]) != 0:
|
||||
|
@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
else:
|
||||
if proto.kind in {Fixed, Length, Path}:
|
||||
if isNil(proto.coder.stringToBuffer):
|
||||
return err("multiaddress: Missing protocol '" & part & "' transcoder")
|
||||
return err("multiaddress: Missing protocol '" &
|
||||
part & "' transcoder")
|
||||
|
||||
if offset + 1 >= len(parts):
|
||||
return err("multiaddress: Missing protocol '" & part & "' argument")
|
||||
|
@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
res.data.write(proto.mcodec)
|
||||
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
|
||||
if not res:
|
||||
return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`")
|
||||
return err("multiaddress: Error encoding `" & part & "/" &
|
||||
parts[offset + 1] & "`")
|
||||
offset += 2
|
||||
|
||||
elif proto.kind == Path:
|
||||
var path = "/" & (parts[(offset + 1)..^1].join("/"))
|
||||
res.data.write(proto.mcodec)
|
||||
if not proto.coder.stringToBuffer(path, res.data):
|
||||
return err("multiaddress: Error encoding `" & part & "/" & path & "`")
|
||||
return err("multiaddress: Error encoding `" & part & "/" &
|
||||
path & "`")
|
||||
|
||||
break
|
||||
elif proto.kind == Marker:
|
||||
|
@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
res.data.finish()
|
||||
ok(res)
|
||||
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress],
|
||||
data: openarray[byte]): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress with array of bytes ``data``.
|
||||
if len(data) == 0:
|
||||
err("multiaddress: Address could not be empty!")
|
||||
|
@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress],
|
|||
var data = initVBuffer()
|
||||
data.write(familyProto.mcodec)
|
||||
var written = familyProto.coder.stringToBuffer($address, data)
|
||||
doAssert written, "Merely writing a string to a buffer should always be possible"
|
||||
doAssert written,
|
||||
"Merely writing a string to a buffer should always be possible"
|
||||
data.write(protoProto.mcodec)
|
||||
written = protoProto.coder.stringToBuffer($port, data)
|
||||
doAssert written, "Merely writing a string to a buffer should always be possible"
|
||||
doAssert written,
|
||||
"Merely writing a string to a buffer should always be possible"
|
||||
data.finish()
|
||||
|
||||
MultiAddress(data: data)
|
||||
|
@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
|
|||
else:
|
||||
ok()
|
||||
|
||||
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} =
|
||||
proc `&`*(m1, m2: MultiAddress): MultiAddress {.
|
||||
raises: [Defect, ResultError[string]].} =
|
||||
## Concatenates two addresses ``m1`` and ``m2``, and returns result.
|
||||
##
|
||||
## This procedure performs validation of concatenated result and can raise
|
||||
## exception on error.
|
||||
concat(m1, m2).tryGet()
|
||||
|
||||
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} =
|
||||
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
|
||||
raises: [Defect, ResultError[string]].} =
|
||||
## Concatenates two addresses ``m1`` and ``m2``.
|
||||
##
|
||||
## This procedure performs validation of concatenated result and can raise
|
||||
|
@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string =
|
|||
result = "(" & sub.join("|") & ")"
|
||||
elif pat.operator == Eq:
|
||||
result = $pat.value
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
|
||||
write(pb, field, value.data.buffer)
|
||||
|
||||
proc getField*(pb: var ProtoBuffer, field: int,
|
||||
value: var MultiAddress): bool {.inline.} =
|
||||
var buffer: seq[byte]
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
let ma = MultiAddress.init(buffer)
|
||||
if ma.isOk():
|
||||
value = ma.get()
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
|
||||
value: var seq[MultiAddress]): bool {.inline.} =
|
||||
var items: seq[seq[byte]]
|
||||
value.setLen(0)
|
||||
if not(getRepeatedField(pb, field, items)):
|
||||
return false
|
||||
if len(items) == 0:
|
||||
return true
|
||||
for item in items:
|
||||
let ma = MultiAddress.init(item)
|
||||
if ma.isOk():
|
||||
value.add(ma.get())
|
||||
else:
|
||||
value.setLen(0)
|
||||
return false
|
||||
|
|
|
@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} =
|
|||
## Write PeerID value ``peerid`` to buffer ``vb``.
|
||||
vb.writeSeq(pid.data)
|
||||
|
||||
proc initProtoField*(index: int, pid: PeerID): ProtoField =
|
||||
proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with PeerID ``value``.
|
||||
result = initProtoField(index, pid.data)
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {.
|
||||
deprecated.} =
|
||||
## Read ``PeerID`` from ProtoBuf's message and validate it.
|
||||
var pid: PeerID
|
||||
result = getLengthValue(data, field, pid.data)
|
||||
|
@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
|
|||
result = -1
|
||||
else:
|
||||
value = pid
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
|
||||
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
|
||||
write(pb, field, pid.data)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
|
||||
## Read ``PeerID`` from ProtoBuf's message and validate it
|
||||
var buffer: seq[byte]
|
||||
var peerId: PeerID
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if peerId.init(buffer):
|
||||
pid = peerId
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import ../varint
|
||||
import ../varint, stew/endians2
|
||||
|
||||
const
|
||||
MaxMessageSize* = 1'u shl 22
|
||||
|
@ -32,10 +32,14 @@ type
|
|||
offset*: int
|
||||
length*: int
|
||||
|
||||
ProtoHeader* = object
|
||||
wire*: ProtoFieldKind
|
||||
index*: uint64
|
||||
|
||||
ProtoField* = object
|
||||
## Protobuf's message field representation object
|
||||
index: int
|
||||
case kind: ProtoFieldKind
|
||||
index*: int
|
||||
case kind*: ProtoFieldKind
|
||||
of Varint:
|
||||
vint*: uint64
|
||||
of Fixed64:
|
||||
|
@ -47,13 +51,35 @@ type
|
|||
of StartGroup, EndGroup:
|
||||
discard
|
||||
|
||||
template protoHeader*(index: int, wire: ProtoFieldKind): uint =
|
||||
## Get protobuf's field header integer for ``index`` and ``wire``.
|
||||
((uint(index) shl 3) or cast[uint](wire))
|
||||
ProtoResult {.pure.} = enum
|
||||
VarintDecodeError,
|
||||
MessageIncompleteError,
|
||||
BufferOverflowError,
|
||||
MessageSizeTooBigError,
|
||||
NoError
|
||||
|
||||
template protoHeader*(field: ProtoField): uint =
|
||||
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
|
||||
hint | hint32 | hint64 | float32 | float64
|
||||
|
||||
const
|
||||
SupportedWireTypes* = {
|
||||
int(ProtoFieldKind.Varint),
|
||||
int(ProtoFieldKind.Fixed64),
|
||||
int(ProtoFieldKind.Length),
|
||||
int(ProtoFieldKind.Fixed32)
|
||||
}
|
||||
|
||||
template checkFieldNumber*(i: int) =
|
||||
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
|
||||
"Incorrect or reserved field number")
|
||||
|
||||
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
|
||||
## Get protobuf's field header integer for ``index`` and ``wire``.
|
||||
((uint64(index) shl 3) or uint64(wire))
|
||||
|
||||
template getProtoHeader*(field: ProtoField): uint64 =
|
||||
## Get protobuf's field header integer for ``field``.
|
||||
((uint(field.index) shl 3) or cast[uint](field.kind))
|
||||
((uint64(field.index) shl 3) or uint64(field.kind))
|
||||
|
||||
template toOpenArray*(pb: ProtoBuffer): untyped =
|
||||
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
|
||||
|
@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int =
|
|||
|
||||
proc vsizeof*(field: ProtoField): int {.inline.} =
|
||||
## Returns number of bytes required to store protobuf's field ``field``.
|
||||
result = vsizeof(protoHeader(field))
|
||||
case field.kind
|
||||
of ProtoFieldKind.Varint:
|
||||
result += vsizeof(field.vint)
|
||||
vsizeof(getProtoHeader(field)) + vsizeof(field.vint)
|
||||
of ProtoFieldKind.Fixed64:
|
||||
result += sizeof(field.vfloat64)
|
||||
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64)
|
||||
of ProtoFieldKind.Fixed32:
|
||||
result += sizeof(field.vfloat32)
|
||||
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
|
||||
of ProtoFieldKind.Length:
|
||||
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer)
|
||||
vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
|
||||
len(field.vbuffer)
|
||||
else:
|
||||
discard
|
||||
0
|
||||
|
||||
proc initProtoField*(index: int, value: SomeVarint): ProtoField =
|
||||
proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with integer value.
|
||||
result = ProtoField(kind: Varint, index: index)
|
||||
when type(value) is uint64:
|
||||
|
@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField =
|
|||
else:
|
||||
result.vint = cast[uint64](value)
|
||||
|
||||
proc initProtoField*(index: int, value: bool): ProtoField =
|
||||
proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with integer value.
|
||||
result = ProtoField(kind: Varint, index: index)
|
||||
result.vint = byte(value)
|
||||
|
||||
proc initProtoField*(index: int, value: openarray[byte]): ProtoField =
|
||||
proc initProtoField*(index: int,
|
||||
value: openarray[byte]): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with bytes array.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: string): ProtoField =
|
||||
proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with string.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} =
|
||||
proc initProtoField*(index: int,
|
||||
value: ProtoBuffer): ProtoField {.deprecated, inline.} =
|
||||
## Initialize ProtoField with nested message stored in ``value``.
|
||||
##
|
||||
## Note: This procedure performs shallow copy of ``value`` sequence.
|
||||
|
@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0,
|
|||
result.offset = offset
|
||||
result.options = options
|
||||
|
||||
proc initProtoBuffer*(data: openarray[byte], offset = 0,
|
||||
options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with copy of ``data``.
|
||||
result.buffer = @data
|
||||
result.offset = offset
|
||||
result.options = options
|
||||
|
||||
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
|
||||
result.buffer = newSeqOfCap[byte](128)
|
||||
|
@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
|||
result.offset = 10
|
||||
elif {WithUint32LeLength, WithUint32BeLength} * options != {}:
|
||||
# Our buffer will start from position 4, so we can store length of buffer
|
||||
# in [0, 9].
|
||||
# in [0, 3].
|
||||
result.buffer.setLen(4)
|
||||
result.offset = 4
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: ProtoField) =
|
||||
proc write*[T: ProtoScalar](pb: var ProtoBuffer,
|
||||
field: int, value: T) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
|
||||
vsizeof(value)
|
||||
let header = ProtoFieldKind.Varint
|
||||
elif T is float32:
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
|
||||
sizeof(T)
|
||||
let header = ProtoFieldKind.Fixed32
|
||||
elif T is float64:
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
|
||||
sizeof(T)
|
||||
let header = ProtoFieldKind.Fixed64
|
||||
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
getProtoHeader(field, header))
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
let vres = putSVarint(pb.toOpenArray(), length, value)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif T is float32:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u32 = cast[uint32](value)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
elif T is float64:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u64 = cast[uint64](value)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
|
||||
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
|
||||
value: openarray[T]) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
let dlength =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
var res = 0
|
||||
for item in value:
|
||||
res += vsizeof(item)
|
||||
res
|
||||
elif (T is float32) or (T is float64):
|
||||
len(value) * sizeof(T)
|
||||
|
||||
let header = getProtoHeader(field, ProtoFieldKind.Length)
|
||||
let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length, header)
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
length = 0
|
||||
let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength))
|
||||
doAssert(lres.isOk())
|
||||
pb.offset += length
|
||||
for item in value:
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
length = 0
|
||||
let vres = PB.putUVarint(pb.toOpenArray(), length, item)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
length = 0
|
||||
let vres = PB.putSVarint(pb.toOpenArray(), length, item)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif T is float32:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u32 = cast[uint32](item)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
elif T is float64:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u64 = cast[uint64](item)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
|
||||
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
|
||||
value: openarray[T]) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
|
||||
vsizeof(uint64(len(value))) + len(value)
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
getProtoHeader(field, ProtoFieldKind.Length))
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
let lres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
uint64(len(value)))
|
||||
doAssert(lres.isOk())
|
||||
pb.offset += length
|
||||
if len(value) > 0:
|
||||
doAssert(pb.isEnough(len(value)))
|
||||
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
|
||||
pb.offset += len(value)
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} =
|
||||
## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer
|
||||
## ``pb`` with field number ``field``.
|
||||
write(pb, field, value.buffer)
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} =
|
||||
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
|
||||
var length = 0
|
||||
var res: VarintResult[void]
|
||||
pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
|
||||
res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field))
|
||||
res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field))
|
||||
doAssert(res.isOk())
|
||||
pb.offset += length
|
||||
case field.kind
|
||||
|
@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) =
|
|||
pb.offset = pos
|
||||
elif WithUint32BeLength in pb.options:
|
||||
let size = uint(len(pb.buffer) - 4)
|
||||
pb.buffer[0] = byte((size shr 24) and 0xFF'u)
|
||||
pb.buffer[1] = byte((size shr 16) and 0xFF'u)
|
||||
pb.buffer[2] = byte((size shr 8) and 0xFF'u)
|
||||
pb.buffer[3] = byte(size and 0xFF'u)
|
||||
pb.buffer[0 ..< 4] = toBytesBE(uint32(size))
|
||||
pb.offset = 4
|
||||
elif WithUint32LeLength in pb.options:
|
||||
let size = uint(len(pb.buffer) - 4)
|
||||
pb.buffer[0] = byte(size and 0xFF'u)
|
||||
pb.buffer[1] = byte((size shr 8) and 0xFF'u)
|
||||
pb.buffer[2] = byte((size shr 16) and 0xFF'u)
|
||||
pb.buffer[3] = byte((size shr 24) and 0xFF'u)
|
||||
pb.buffer[0 ..< 4] = toBytesLE(uint32(size))
|
||||
pb.offset = 4
|
||||
else:
|
||||
pb.offset = 0
|
||||
|
||||
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
|
||||
var length = 0
|
||||
var hdr = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
|
||||
let index = uint64(hdr shr 3)
|
||||
let wire = hdr and 0x07
|
||||
if wire in SupportedWireTypes:
|
||||
data.offset += length
|
||||
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
|
||||
true
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
|
||||
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
|
||||
case header.wire
|
||||
of ProtoFieldKind.Varint:
|
||||
var length = 0
|
||||
var value = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Fixed32:
|
||||
if data.isEnough(sizeof(uint32)):
|
||||
data.offset += sizeof(uint32)
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Fixed64:
|
||||
if data.isEnough(sizeof(uint64)):
|
||||
data.offset += sizeof(uint64)
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Length:
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
data.offset += int(bsize)
|
||||
true
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
|
||||
false
|
||||
|
||||
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
|
||||
header: ProtoHeader,
|
||||
outval: var T): ProtoResult =
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
var length = 0
|
||||
var value = T(0)
|
||||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
var length = 0
|
||||
var value = T(0)
|
||||
if getSVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
elif T is float32:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed32)
|
||||
if data.isEnough(sizeof(float32)):
|
||||
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
|
||||
data.offset += sizeof(float32)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
elif T is float64:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed64)
|
||||
if data.isEnough(sizeof(float64)):
|
||||
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
|
||||
data.offset += sizeof(float64)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
|
||||
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var openarray[T],
|
||||
outLength: var int): ProtoResult =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
|
||||
outLength = 0
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
outLength = int(bsize)
|
||||
if len(outBytes) >= int(bsize):
|
||||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
# Buffer overflow should not be critical failure
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.BufferOverflowError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
|
||||
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var T): ProtoResult =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
outBytes.setLen(0)
|
||||
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
outBytes.setLen(bsize)
|
||||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
|
||||
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
checkFieldNumber(field)
|
||||
var value: T
|
||||
var res = false
|
||||
var pb = data
|
||||
output = T(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output = T(0)
|
||||
return false
|
||||
let wireCheck =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
header.wire == ProtoFieldKind.Varint
|
||||
elif T is float32:
|
||||
header.wire == ProtoFieldKind.Fixed32
|
||||
elif T is float64:
|
||||
header.wire == ProtoFieldKind.Fixed64
|
||||
if header.index == uint64(field):
|
||||
if wireCheck:
|
||||
let r = getValue(pb, header, value)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
output = value
|
||||
else:
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from what we expect, because it
|
||||
# is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
res
|
||||
|
||||
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
|
||||
output: var openarray[T],
|
||||
outlen: var int): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
var res = false
|
||||
|
||||
outlen = 0
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output, outlen)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
|
||||
res
|
||||
|
||||
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
checkFieldNumber(field)
|
||||
var res = false
|
||||
var pb = data
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
res
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
|
||||
inline.} =
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(field, buffer):
|
||||
output = initProtoBuffer(buffer)
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
|
||||
field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
|
||||
ProtoFieldKind.Fixed64}:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var arritem: seq[byte]
|
||||
let rarr = getValue(pb, header, arritem)
|
||||
case rarr
|
||||
of ProtoResult.NoError:
|
||||
var pbarr = initProtoBuffer(arritem)
|
||||
let itemHeader =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
ProtoHeader(wire: ProtoFieldKind.Varint)
|
||||
elif T is float32:
|
||||
ProtoHeader(wire: ProtoFieldKind.Fixed32)
|
||||
elif T is float64:
|
||||
ProtoHeader(wire: ProtoFieldKind.Fixed64)
|
||||
while not(pbarr.isEmpty()):
|
||||
var item: T
|
||||
let res = getValue(pbarr, itemHeader, item)
|
||||
case res
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getVarintValue*(data: var ProtoBuffer, field: int,
|
||||
value: var SomeVarint): int =
|
||||
value: var SomeVarint): int {.deprecated.} =
|
||||
## Get value of `Varint` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
var soffset = data.offset
|
||||
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, header).isOk():
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Varint):
|
||||
if header == getProtoHeader(field, Varint):
|
||||
if not data.isEmpty():
|
||||
when type(value) is int32 or type(value) is int64 or type(value) is int:
|
||||
let res = getSVarint(data.toOpenArray(), length, value)
|
||||
|
@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int,
|
|||
data.offset = soffset
|
||||
|
||||
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
||||
buffer: var T): int =
|
||||
buffer: var T): int {.deprecated.} =
|
||||
## Get value of `Length` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
|
@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
|||
var soffset = data.offset
|
||||
result = -1
|
||||
buffer.setLen(0)
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, header).isOk():
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Length):
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk():
|
||||
if header == getProtoHeader(field, Length):
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, ssize).isOk():
|
||||
data.offset += length
|
||||
if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
|
||||
buffer.setLen(ssize)
|
||||
|
@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
|||
data.offset = soffset
|
||||
|
||||
proc getBytes*(data: var ProtoBuffer, field: int,
|
||||
buffer: var seq[byte]): int {.inline.} =
|
||||
buffer: var seq[byte]): int {.deprecated, inline.} =
|
||||
## Get value of `Length` type as bytes.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc getString*(data: var ProtoBuffer, field: int,
|
||||
buffer: var string): int {.inline.} =
|
||||
buffer: var string): int {.deprecated, inline.} =
|
||||
## Get value of `Length` type as string.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc enterSubmessage*(pb: var ProtoBuffer): int =
|
||||
proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} =
|
||||
## Processes protobuf's sub-message and adjust internal offset to enter
|
||||
## inside of sub-message. Returns field index of sub-message field or
|
||||
## ``0`` on error.
|
||||
|
@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
|
|||
var msize = 0'u64
|
||||
var soffset = pb.offset
|
||||
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk():
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
|
||||
length, header).isOk():
|
||||
pb.offset += length
|
||||
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk():
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
|
||||
length, msize).isOk():
|
||||
pb.offset += length
|
||||
if msize <= MaxMessageSize and pb.isEnough(int(msize)):
|
||||
pb.length = int(msize)
|
||||
|
@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
|
|||
# Restore offset on error
|
||||
pb.offset = soffset
|
||||
|
||||
proc skipSubmessage*(pb: var ProtoBuffer) =
|
||||
proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} =
|
||||
## Skip current protobuf's sub-message and adjust internal offset to the
|
||||
## end of sub-message.
|
||||
doAssert(pb.length != 0)
|
||||
|
|
|
@ -47,61 +47,49 @@ type
|
|||
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
|
||||
result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet()))
|
||||
result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
|
||||
|
||||
for ma in peerInfo.addrs:
|
||||
result.write(initProtoField(2, ma.data.buffer))
|
||||
result.write(2, ma.data.buffer)
|
||||
|
||||
for proto in peerInfo.protocols:
|
||||
result.write(initProtoField(3, proto))
|
||||
result.write(3, proto)
|
||||
|
||||
result.write(initProtoField(4, observedAddr.data.buffer))
|
||||
result.write(4, observedAddr.data.buffer)
|
||||
|
||||
let protoVersion = ProtoVersion
|
||||
result.write(initProtoField(5, protoVersion))
|
||||
result.write(5, protoVersion)
|
||||
|
||||
let agentVersion = AgentVersion
|
||||
result.write(initProtoField(6, agentVersion))
|
||||
result.write(6, agentVersion)
|
||||
result.finish()
|
||||
|
||||
proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
|
||||
var pb = initProtoBuffer(buf)
|
||||
|
||||
result.pubKey = none(PublicKey)
|
||||
var pubKey: PublicKey
|
||||
if pb.getValue(1, pubKey) > 0:
|
||||
if pb.getField(1, pubKey):
|
||||
trace "read public key from message", pubKey = ($pubKey).shortLog
|
||||
result.pubKey = some(pubKey)
|
||||
|
||||
result.addrs = newSeq[MultiAddress]()
|
||||
var address = newSeq[byte]()
|
||||
while pb.getBytes(2, address) > 0:
|
||||
if len(address) != 0:
|
||||
var copyaddr = address
|
||||
var ma = MultiAddress.init(copyaddr).tryGet()
|
||||
result.addrs.add(ma)
|
||||
trace "read address bytes from message", address = ma
|
||||
address.setLen(0)
|
||||
if pb.getRepeatedField(2, result.addrs):
|
||||
trace "read addresses from message", addresses = result.addrs
|
||||
|
||||
var proto = ""
|
||||
while pb.getString(3, proto) > 0:
|
||||
trace "read proto from message", proto = proto
|
||||
result.protos.add(proto)
|
||||
proto = ""
|
||||
if pb.getRepeatedField(3, result.protos):
|
||||
trace "read protos from message", protocols = result.protos
|
||||
|
||||
var observableAddr = newSeq[byte]()
|
||||
if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr
|
||||
var ma = MultiAddress.init(observableAddr).tryGet()
|
||||
trace "read observedAddr from message", address = ma
|
||||
result.observedAddr = some(ma)
|
||||
var observableAddr: MultiAddress
|
||||
if pb.getField(4, observableAddr):
|
||||
trace "read observableAddr from message", address = observableAddr
|
||||
result.observedAddr = some(observableAddr)
|
||||
|
||||
var protoVersion = ""
|
||||
if pb.getString(5, protoVersion) > 0:
|
||||
if pb.getField(5, protoVersion):
|
||||
trace "read protoVersion from message", protoVersion = protoVersion
|
||||
result.protoVersion = some(protoVersion)
|
||||
|
||||
var agentVersion = ""
|
||||
if pb.getString(6, agentVersion) > 0:
|
||||
if pb.getField(6, agentVersion):
|
||||
trace "read agentVersion from message", agentVersion = agentVersion
|
||||
result.agentVersion = some(agentVersion)
|
||||
|
||||
|
|
|
@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string =
|
|||
byteutils.toHex(m.seqno) & m.fromPeer.pretty
|
||||
|
||||
proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} =
|
||||
var buff = initProtoBuffer()
|
||||
encodeMessage(msg, buff)
|
||||
p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes()
|
||||
p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes()
|
||||
|
||||
proc verify*(m: Message, p: PeerInfo): bool =
|
||||
if m.signature.len > 0 and m.key.len > 0:
|
||||
|
@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool =
|
|||
msg.signature = @[]
|
||||
msg.key = @[]
|
||||
|
||||
var buff = initProtoBuffer()
|
||||
encodeMessage(msg, buff)
|
||||
|
||||
var remote: Signature
|
||||
var key: PublicKey
|
||||
if remote.init(m.signature) and key.init(m.key):
|
||||
trace "verifying signature", remoteSignature = remote
|
||||
result = remote.verify(PubSubPrefix & buff.buffer, key)
|
||||
result = remote.verify(PubSubPrefix & encodeMessage(msg), key)
|
||||
|
||||
if result:
|
||||
libp2p_pubsub_sig_verify_success.inc()
|
||||
|
|
|
@ -14,265 +14,247 @@ import messages,
|
|||
../../../utility,
|
||||
../../../protobuf/minprotobuf
|
||||
|
||||
proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, graft.topicID))
|
||||
proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, graft.topicID)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} =
|
||||
trace "decoding graft msg", buffer = pb.buffer.shortLog
|
||||
while true:
|
||||
var topic: string
|
||||
if pb.getString(1, topic) < 0:
|
||||
break
|
||||
proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, prune.topicID)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
trace "read topic field from graft msg", topicID = topic
|
||||
result.add(ControlGraft(topicID: topic))
|
||||
|
||||
proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, prune.topicID))
|
||||
|
||||
proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
|
||||
trace "decoding prune msg"
|
||||
while true:
|
||||
var topic: string
|
||||
if pb.getString(1, topic) < 0:
|
||||
break
|
||||
|
||||
trace "read topic field from prune msg", topicID = topic
|
||||
result.add(ControlPrune(topicID: topic))
|
||||
|
||||
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, ihave.topicID))
|
||||
proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, ihave.topicID)
|
||||
for mid in ihave.messageIDs:
|
||||
pb.write(initProtoField(2, mid))
|
||||
ipb.write(2, mid)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} =
|
||||
trace "decoding ihave msg"
|
||||
|
||||
while true:
|
||||
var control: ControlIHave
|
||||
if pb.getString(1, control.topicID) < 0:
|
||||
trace "topic field missing from ihave msg"
|
||||
break
|
||||
|
||||
trace "read topic field", topicID = control.topicID
|
||||
|
||||
while true:
|
||||
var mid: string
|
||||
if pb.getString(2, mid) < 0:
|
||||
break
|
||||
trace "read messageID field", mid = mid
|
||||
control.messageIDs.add(mid)
|
||||
|
||||
result.add(control)
|
||||
|
||||
proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
|
||||
var ipb = initProtoBuffer()
|
||||
for mid in iwant.messageIDs:
|
||||
pb.write(initProtoField(1, mid))
|
||||
ipb.write(1, mid)
|
||||
if len(ipb.buffer) > 0:
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} =
|
||||
trace "decoding iwant msg"
|
||||
proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
|
||||
var ipb = initProtoBuffer()
|
||||
for ihave in control.ihave:
|
||||
ipb.write(1, ihave)
|
||||
for iwant in control.iwant:
|
||||
ipb.write(2, iwant)
|
||||
for graft in control.graft:
|
||||
ipb.write(3, graft)
|
||||
for prune in control.prune:
|
||||
ipb.write(4, prune)
|
||||
if len(ipb.buffer) > 0:
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
var control: ControlIWant
|
||||
while true:
|
||||
var mid: string
|
||||
if pb.getString(1, mid) < 0:
|
||||
break
|
||||
control.messageIDs.add(mid)
|
||||
trace "read messageID field", mid = mid
|
||||
result.add(control)
|
||||
|
||||
proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
if control.ihave.len > 0:
|
||||
var ihave = initProtoBuffer()
|
||||
for h in control.ihave:
|
||||
h.encodeIHave(ihave)
|
||||
|
||||
# write messages to protobuf
|
||||
if ihave.buffer.len > 0:
|
||||
ihave.finish()
|
||||
pb.write(initProtoField(1, ihave))
|
||||
|
||||
if control.iwant.len > 0:
|
||||
var iwant = initProtoBuffer()
|
||||
for w in control.iwant:
|
||||
w.encodeIWant(iwant)
|
||||
|
||||
# write messages to protobuf
|
||||
if iwant.buffer.len > 0:
|
||||
iwant.finish()
|
||||
pb.write(initProtoField(2, iwant))
|
||||
|
||||
if control.graft.len > 0:
|
||||
var graft = initProtoBuffer()
|
||||
for g in control.graft:
|
||||
g.encodeGraft(graft)
|
||||
|
||||
# write messages to protobuf
|
||||
if graft.buffer.len > 0:
|
||||
graft.finish()
|
||||
pb.write(initProtoField(3, graft))
|
||||
|
||||
if control.prune.len > 0:
|
||||
var prune = initProtoBuffer()
|
||||
for p in control.prune:
|
||||
p.encodePrune(prune)
|
||||
|
||||
# write messages to protobuf
|
||||
if prune.buffer.len > 0:
|
||||
prune.finish()
|
||||
pb.write(initProtoField(4, prune))
|
||||
|
||||
proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} =
|
||||
trace "decoding control submessage"
|
||||
var control: ControlMessage
|
||||
while true:
|
||||
var field = pb.enterSubMessage()
|
||||
trace "processing submessage", field = field
|
||||
case field:
|
||||
of 0:
|
||||
trace "no submessage found in Control msg"
|
||||
break
|
||||
of 1:
|
||||
control.ihave &= pb.decodeIHave()
|
||||
of 2:
|
||||
control.iwant &= pb.decodeIWant()
|
||||
of 3:
|
||||
control.graft &= pb.decodeGraft()
|
||||
of 4:
|
||||
control.prune &= pb.decodePrune()
|
||||
else:
|
||||
raise newException(CatchableError, "message type not recognized")
|
||||
|
||||
if result.isNone:
|
||||
result = some(control)
|
||||
|
||||
proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, subs.subscribe))
|
||||
pb.write(initProtoField(2, subs.topic))
|
||||
|
||||
proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} =
|
||||
while true:
|
||||
var subOpt: SubOpts
|
||||
var subscr: uint
|
||||
discard pb.getVarintValue(1, subscr)
|
||||
subOpt.subscribe = cast[bool](subscr)
|
||||
trace "read subscribe field", subscribe = subOpt.subscribe
|
||||
|
||||
if pb.getString(2, subOpt.topic) < 0:
|
||||
break
|
||||
trace "read subscribe field", topicName = subOpt.topic
|
||||
|
||||
result.add(subOpt)
|
||||
|
||||
trace "got subscriptions", subscriptions = result
|
||||
|
||||
proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, msg.fromPeer.getBytes()))
|
||||
pb.write(initProtoField(2, msg.data))
|
||||
pb.write(initProtoField(3, msg.seqno))
|
||||
|
||||
for t in msg.topicIDs:
|
||||
pb.write(initProtoField(4, t))
|
||||
|
||||
if msg.signature.len > 0:
|
||||
pb.write(initProtoField(5, msg.signature))
|
||||
|
||||
if msg.key.len > 0:
|
||||
pb.write(initProtoField(6, msg.key))
|
||||
proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, uint64(subs.subscribe))
|
||||
ipb.write(2, subs.topic)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc encodeMessage*(msg: Message): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, msg.fromPeer)
|
||||
pb.write(2, msg.data)
|
||||
pb.write(3, msg.seqno)
|
||||
for topic in msg.topicIDs:
|
||||
pb.write(4, topic)
|
||||
if len(msg.signature) > 0:
|
||||
pb.write(5, msg.signature)
|
||||
if len(msg.key) > 0:
|
||||
pb.write(6, msg.key)
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} =
|
||||
# TODO: which of this fields are really optional?
|
||||
while true:
|
||||
var msg: Message
|
||||
var fromPeer: seq[byte]
|
||||
if pb.getBytes(1, fromPeer) < 0:
|
||||
break
|
||||
try:
|
||||
msg.fromPeer = PeerID.init(fromPeer).tryGet()
|
||||
except CatchableError as err:
|
||||
debug "Invalid fromPeer in message", msg = err.msg
|
||||
break
|
||||
proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
|
||||
pb.write(field, encodeMessage(msg))
|
||||
|
||||
trace "read message field", fromPeer = msg.fromPeer.pretty
|
||||
|
||||
if pb.getBytes(2, msg.data) < 0:
|
||||
break
|
||||
trace "read message field", data = msg.data.shortLog
|
||||
|
||||
if pb.getBytes(3, msg.seqno) < 0:
|
||||
break
|
||||
trace "read message field", seqno = msg.seqno.shortLog
|
||||
|
||||
var topic: string
|
||||
while true:
|
||||
if pb.getString(4, topic) < 0:
|
||||
break
|
||||
msg.topicIDs.add(topic)
|
||||
trace "read message field", topicName = topic
|
||||
topic = ""
|
||||
|
||||
discard pb.getBytes(5, msg.signature)
|
||||
trace "read message field", signature = msg.signature.shortLog
|
||||
|
||||
discard pb.getBytes(6, msg.key)
|
||||
trace "read message field", key = msg.key.shortLog
|
||||
|
||||
result.add(msg)
|
||||
|
||||
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} =
|
||||
result = initProtoBuffer()
|
||||
trace "encoding msg: ", msg = msg.shortLog
|
||||
|
||||
if msg.subscriptions.len > 0:
|
||||
for s in msg.subscriptions:
|
||||
var subs = initProtoBuffer()
|
||||
encodeSubs(s, subs)
|
||||
|
||||
# write subscriptions to protobuf
|
||||
if subs.buffer.len > 0:
|
||||
subs.finish()
|
||||
result.write(initProtoField(1, subs))
|
||||
|
||||
if msg.messages.len > 0:
|
||||
var messages = initProtoBuffer()
|
||||
for m in msg.messages:
|
||||
encodeMessage(m, messages)
|
||||
|
||||
# write messages to protobuf
|
||||
if messages.buffer.len > 0:
|
||||
messages.finish()
|
||||
result.write(initProtoField(2, messages))
|
||||
|
||||
if msg.control.isSome:
|
||||
var control = initProtoBuffer()
|
||||
msg.control.get.encodeControl(control)
|
||||
|
||||
# write messages to protobuf
|
||||
if control.buffer.len > 0:
|
||||
control.finish()
|
||||
result.write(initProtoField(3, control))
|
||||
|
||||
if result.buffer.len > 0:
|
||||
result.finish()
|
||||
|
||||
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} =
|
||||
var pb = initProtoBuffer(msg)
|
||||
|
||||
while true:
|
||||
# decode SubOpts array
|
||||
var field = pb.enterSubMessage()
|
||||
trace "processing submessage", field = field
|
||||
case field:
|
||||
of 0:
|
||||
trace "no submessage found in RPC msg"
|
||||
break
|
||||
of 1:
|
||||
result.subscriptions &= pb.decodeSubs()
|
||||
of 2:
|
||||
result.messages &= pb.decodeMessages()
|
||||
of 3:
|
||||
result.control = pb.decodeControl()
|
||||
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
|
||||
trace "decodeGraft: decoding message"
|
||||
var control = ControlGraft()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeGraft: read topicId", topic_id = topicId
|
||||
else:
|
||||
raise newException(CatchableError, "message type not recognized")
|
||||
trace "decodeGraft: topicId is missing"
|
||||
control
|
||||
|
||||
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
|
||||
trace "decodePrune: decoding message"
|
||||
var control = ControlPrune()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodePrune: read topicId", topic_id = topicId
|
||||
else:
|
||||
trace "decodePrune: topicId is missing"
|
||||
control
|
||||
|
||||
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
|
||||
trace "decodeIHave: decoding message"
|
||||
var control = ControlIHave()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeIHave: read topicId", topic_id = topicId
|
||||
else:
|
||||
trace "decodeIHave: topicId is missing"
|
||||
if pb.getRepeatedField(2, control.messageIDs):
|
||||
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIHave: no messageIDs"
|
||||
control
|
||||
|
||||
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
|
||||
trace "decodeIWant: decoding message"
|
||||
var control = ControlIWant()
|
||||
if pb.getRepeatedField(1, control.messageIDs):
|
||||
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIWant: no messageIDs"
|
||||
|
||||
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
|
||||
trace "decodeControl: decoding message"
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(3, buffer):
|
||||
var control: ControlMessage
|
||||
var cpb = initProtoBuffer(buffer)
|
||||
var ihavepbs: seq[seq[byte]]
|
||||
var iwantpbs: seq[seq[byte]]
|
||||
var graftpbs: seq[seq[byte]]
|
||||
var prunepbs: seq[seq[byte]]
|
||||
|
||||
discard cpb.getRepeatedField(1, ihavepbs)
|
||||
discard cpb.getRepeatedField(2, iwantpbs)
|
||||
discard cpb.getRepeatedField(3, graftpbs)
|
||||
discard cpb.getRepeatedField(4, prunepbs)
|
||||
|
||||
for item in ihavepbs:
|
||||
control.ihave.add(decodeIHave(initProtoBuffer(item)))
|
||||
for item in iwantpbs:
|
||||
control.iwant.add(decodeIWant(initProtoBuffer(item)))
|
||||
for item in graftpbs:
|
||||
control.graft.add(decodeGraft(initProtoBuffer(item)))
|
||||
for item in prunepbs:
|
||||
control.prune.add(decodePrune(initProtoBuffer(item)))
|
||||
|
||||
trace "decodeControl: "
|
||||
some(control)
|
||||
else:
|
||||
none[ControlMessage]()
|
||||
|
||||
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
|
||||
trace "decodeSubscription: decoding message"
|
||||
var subflag: uint64
|
||||
var sub = SubOpts()
|
||||
if pb.getField(1, subflag):
|
||||
sub.subscribe = bool(subflag)
|
||||
trace "decodeSubscription: read subscribe", subscribe = subflag
|
||||
else:
|
||||
trace "decodeSubscription: subscribe is missing"
|
||||
if pb.getField(2, sub.topic):
|
||||
trace "decodeSubscription: read topic", topic = sub.topic
|
||||
else:
|
||||
trace "decodeSubscription: topic is missing"
|
||||
|
||||
sub
|
||||
|
||||
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
|
||||
trace "decodeSubscriptions: decoding message"
|
||||
var subpbs: seq[seq[byte]]
|
||||
var subs: seq[SubOpts]
|
||||
if pb.getRepeatedField(1, subpbs):
|
||||
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
|
||||
for item in subpbs:
|
||||
let sub = decodeSubscription(initProtoBuffer(item))
|
||||
subs.add(sub)
|
||||
|
||||
if len(subs) == 0:
|
||||
trace "decodeSubscription: no subscriptions found"
|
||||
|
||||
subs
|
||||
|
||||
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
|
||||
trace "decodeMessage: decoding message"
|
||||
var msg: Message
|
||||
if pb.getField(1, msg.fromPeer):
|
||||
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
|
||||
else:
|
||||
trace "decodeMessage: fromPeer is missing"
|
||||
|
||||
if pb.getField(2, msg.data):
|
||||
trace "decodeMessage: read data", data = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: data is missing"
|
||||
|
||||
if pb.getField(3, msg.seqno):
|
||||
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: seqno is missing"
|
||||
|
||||
if pb.getRepeatedField(4, msg.topicIDs):
|
||||
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
|
||||
else:
|
||||
trace "decodeMessage: topics are missing"
|
||||
|
||||
if pb.getField(5, msg.signature):
|
||||
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: signature is missing"
|
||||
|
||||
if pb.getField(6, msg.key):
|
||||
trace "decodeMessage: read public key", key = msg.key.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: public key is missing"
|
||||
|
||||
msg
|
||||
|
||||
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
|
||||
trace "decodeMessages: decoding message"
|
||||
var msgpbs: seq[seq[byte]]
|
||||
var msgs: seq[Message]
|
||||
if pb.getRepeatedField(2, msgpbs):
|
||||
trace "decodeMessages: read messages", count = len(msgpbs)
|
||||
for item in msgpbs:
|
||||
let msg = decodeMessage(initProtoBuffer(item))
|
||||
msgs.add(msg)
|
||||
|
||||
if len(msgs) == 0:
|
||||
trace "decodeMessages: no messages found"
|
||||
|
||||
msgs
|
||||
|
||||
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
|
||||
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer()
|
||||
for item in msg.subscriptions:
|
||||
pb.write(1, item)
|
||||
for item in msg.messages:
|
||||
pb.write(2, item)
|
||||
if msg.control.isSome():
|
||||
pb.write(3, msg.control.get())
|
||||
if len(pb.buffer) > 0:
|
||||
pb.finish()
|
||||
result = pb
|
||||
|
||||
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
|
||||
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer(msg)
|
||||
var rpcMsg: RPCMsg
|
||||
rpcMsg.messages = pb.decodeMessages()
|
||||
rpcMsg.subscriptions = pb.decodeSubscriptions()
|
||||
rpcMsg.control = pb.decodeControl()
|
||||
|
||||
rpcMsg
|
||||
|
|
|
@ -430,8 +430,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
|
||||
var
|
||||
libp2pProof = initProtoBuffer()
|
||||
libp2pProof.write(initProtoField(1, p.localPublicKey))
|
||||
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
|
||||
libp2pProof.write(1, p.localPublicKey)
|
||||
libp2pProof.write(2, signedPayload.getBytes())
|
||||
# data field also there but not used!
|
||||
libp2pProof.finish()
|
||||
|
||||
|
@ -449,9 +449,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
remoteSig: Signature
|
||||
remoteSigBytes: seq[byte]
|
||||
|
||||
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
|
||||
if not(remoteProof.getField(1, remotePubKeyBytes)):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
|
||||
if not(remoteProof.getField(2, remoteSigBytes)):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
|
||||
if not remotePubKey.init(remotePubKeyBytes):
|
||||
|
|
|
@ -0,0 +1,677 @@
|
|||
## Nim-Libp2p
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import unittest
|
||||
import ../libp2p/protobuf/minprotobuf
|
||||
import stew/byteutils, strutils
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
suite "MinProtobuf test suite":
|
||||
const VarintVectors = [
|
||||
"0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f",
|
||||
"08ffffffffffffffffff01"
|
||||
]
|
||||
|
||||
const VarintValues = [
|
||||
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64,
|
||||
0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
]
|
||||
|
||||
const Fixed32Vectors = [
|
||||
"0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff"
|
||||
]
|
||||
|
||||
const Fixed32Values = [
|
||||
0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32
|
||||
]
|
||||
|
||||
const Fixed64Vectors = [
|
||||
"090000000000000000", "090100000000000000", "09ffffff7f00000000",
|
||||
"09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f",
|
||||
"099988ffeeddccbbaa", "09ffffffffffffffff"
|
||||
]
|
||||
|
||||
const Fixed64Values = [
|
||||
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64,
|
||||
0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64,
|
||||
0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
]
|
||||
|
||||
const LengthVectors = [
|
||||
"0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768"
|
||||
]
|
||||
|
||||
const LengthValues = [
|
||||
"", "a", "ab", "abcd", "abcdefgh"
|
||||
]
|
||||
|
||||
## This vector values was tested with `protoc` and related proto file.
|
||||
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated uint64 d = 1 [packed=true];
|
||||
## repeated uint64 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedVarintVector =
|
||||
"0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001"
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated sfixed32 d = 1 [packed=true];
|
||||
## repeated sfixed32 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedFixed32Vector =
|
||||
"0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000"
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated sfixed64 d = 1 [packed=true];
|
||||
## repeated sfixed64 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedFixed64Vector =
|
||||
"""0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000
|
||||
ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210
|
||||
00000000000000000100000000000000"""
|
||||
|
||||
proc getVarintEncodedValue(value: uint64): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getVarintDecodedValue(data: openarray[byte]): uint64 =
|
||||
var value: uint64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
value
|
||||
|
||||
proc getFixed32EncodedValue(value: float32): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getFixed32DecodedValue(data: openarray[byte]): uint32 =
|
||||
var value: float32
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
cast[uint32](value)
|
||||
|
||||
proc getFixed64EncodedValue(value: float64): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getFixed64DecodedValue(data: openarray[byte]): uint64 =
|
||||
var value: float64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
cast[uint64](value)
|
||||
|
||||
proc getLengthEncodedValue(value: string): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getLengthEncodedValue(value: seq[byte]): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getLengthDecodedValue(data: openarray[byte]): string =
|
||||
var value = newString(len(data))
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
|
||||
doAssert(res)
|
||||
value.setLen(valueLen)
|
||||
value
|
||||
|
||||
proc isFullZero[T: byte|char](data: openarray[T]): bool =
|
||||
for ch in data:
|
||||
if int(ch) != 0:
|
||||
return false
|
||||
return true
|
||||
|
||||
proc corruptHeader(data: var openarray[byte], index: int) =
|
||||
var values = [3, 4, 6]
|
||||
data[0] = data[0] and 0xF8'u8
|
||||
data[0] = data[0] or byte(values[index mod len(values)])
|
||||
|
||||
test "[varint] edge values test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
let data = getVarintEncodedValue(VarintValues[i])
|
||||
check:
|
||||
toHex(data) == VarintVectors[i]
|
||||
getVarintDecodedValue(data) == VarintValues[i]
|
||||
|
||||
test "[varint] mixing many values with same field number test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(VarintValues):
|
||||
let index = (i + k + 1) mod len(VarintValues)
|
||||
pb.write(1, VarintValues[index])
|
||||
pb.finish()
|
||||
check getVarintDecodedValue(pb.buffer) == VarintValues[i]
|
||||
|
||||
test "[varint] incorrect values test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[varint] non-existent field test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == 0'u64
|
||||
|
||||
test "[varint] corrupted header test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
for k in 0 ..< 3:
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[varint] empty buffer test":
|
||||
var value: uint64
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == 0'u64
|
||||
|
||||
test "[varint] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, VarintValues[1])
|
||||
pb1.write(1, VarintValues[2])
|
||||
pb1.write(2, VarintValues[3])
|
||||
pb1.write(1, VarintValues[4])
|
||||
pb1.write(1, VarintValues[5])
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[uint64]
|
||||
var fieldarr2: seq[uint64]
|
||||
var fieldarr3: seq[uint64]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
fieldarr1[0] == VarintValues[1]
|
||||
fieldarr1[1] == VarintValues[2]
|
||||
fieldarr1[2] == VarintValues[4]
|
||||
fieldarr1[3] == VarintValues[5]
|
||||
fieldarr2[0] == VarintValues[3]
|
||||
|
||||
test "[varint] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.writePacked(1, VarintValues)
|
||||
pb1.writePacked(2, VarintValues[0 .. 1])
|
||||
pb1.finish()
|
||||
check:
|
||||
toHex(pb1.buffer) == PackedVarintVector
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[uint64]
|
||||
var fieldarr2: seq[uint64]
|
||||
var fieldarr3: seq[uint64]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 6
|
||||
fieldarr1[0] == VarintValues[0]
|
||||
fieldarr1[1] == VarintValues[1]
|
||||
fieldarr1[2] == VarintValues[2]
|
||||
fieldarr1[3] == VarintValues[3]
|
||||
fieldarr1[4] == VarintValues[4]
|
||||
fieldarr1[5] == VarintValues[5]
|
||||
fieldarr2[0] == VarintValues[0]
|
||||
fieldarr2[1] == VarintValues[1]
|
||||
|
||||
test "[fixed32] edge values test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i]))
|
||||
check:
|
||||
toHex(data) == Fixed32Vectors[i]
|
||||
getFixed32DecodedValue(data) == Fixed32Values[i]
|
||||
|
||||
test "[fixed32] mixing many values with same field number test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(Fixed32Values):
|
||||
let index = (i + k + 1) mod len(Fixed32Values)
|
||||
pb.write(1, cast[float32](Fixed32Values[index]))
|
||||
pb.finish()
|
||||
check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i]
|
||||
|
||||
test "[fixed32] incorrect values test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed32] non-existent field test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float32(0)
|
||||
|
||||
test "[fixed32] corrupted header test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
for k in 0 ..< 3:
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed32] empty buffer test":
|
||||
var value: float32
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float32(0)
|
||||
|
||||
test "[fixed32] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, cast[float32](Fixed32Values[0]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[1]))
|
||||
pb1.write(2, cast[float32](Fixed32Values[2]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[3]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[4]))
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float32]
|
||||
var fieldarr2: seq[float32]
|
||||
var fieldarr3: seq[float32]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[uint32](fieldarr1[0]) == Fixed64Values[0]
|
||||
cast[uint32](fieldarr1[1]) == Fixed64Values[1]
|
||||
cast[uint32](fieldarr1[2]) == Fixed64Values[3]
|
||||
cast[uint32](fieldarr1[3]) == Fixed64Values[4]
|
||||
cast[uint32](fieldarr2[0]) == Fixed64Values[2]
|
||||
|
||||
test "[fixed32] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var values = newSeq[float32](len(Fixed32Values))
|
||||
for i in 0 ..< len(values):
|
||||
values[i] = cast[float32](Fixed32Values[i])
|
||||
pb1.writePacked(1, values)
|
||||
pb1.writePacked(2, values[0 .. 1])
|
||||
pb1.finish()
|
||||
check:
|
||||
toHex(pb1.buffer) == PackedFixed32Vector
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float32]
|
||||
var fieldarr2: seq[float32]
|
||||
var fieldarr3: seq[float32]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 5
|
||||
cast[uint32](fieldarr1[0]) == Fixed32Values[0]
|
||||
cast[uint32](fieldarr1[1]) == Fixed32Values[1]
|
||||
cast[uint32](fieldarr1[2]) == Fixed32Values[2]
|
||||
cast[uint32](fieldarr1[3]) == Fixed32Values[3]
|
||||
cast[uint32](fieldarr1[4]) == Fixed32Values[4]
|
||||
cast[uint32](fieldarr2[0]) == Fixed32Values[0]
|
||||
cast[uint32](fieldarr2[1]) == Fixed32Values[1]
|
||||
|
||||
test "[fixed64] edge values test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
check:
|
||||
toHex(data) == Fixed64Vectors[i]
|
||||
getFixed64DecodedValue(data) == Fixed64Values[i]
|
||||
|
||||
test "[fixed64] mixing many values with same field number test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(Fixed64Values):
|
||||
let index = (i + k + 1) mod len(Fixed64Values)
|
||||
pb.write(1, cast[float64](Fixed64Values[index]))
|
||||
pb.finish()
|
||||
check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i]
|
||||
|
||||
test "[fixed64] incorrect values test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var value: float32
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed64] non-existent field test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var value: float64
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float64(0)
|
||||
|
||||
test "[fixed64] corrupted header test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
for k in 0 ..< 3:
|
||||
var value: float64
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed64] empty buffer test":
|
||||
var value: float64
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float64(0)
|
||||
|
||||
test "[fixed64] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, cast[float64](Fixed64Values[2]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[3]))
|
||||
pb1.write(2, cast[float64](Fixed64Values[4]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[5]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[6]))
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float64]
|
||||
var fieldarr2: seq[float64]
|
||||
var fieldarr3: seq[float64]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[uint64](fieldarr1[0]) == Fixed64Values[2]
|
||||
cast[uint64](fieldarr1[1]) == Fixed64Values[3]
|
||||
cast[uint64](fieldarr1[2]) == Fixed64Values[5]
|
||||
cast[uint64](fieldarr1[3]) == Fixed64Values[6]
|
||||
cast[uint64](fieldarr2[0]) == Fixed64Values[4]
|
||||
|
||||
test "[fixed64] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var values = newSeq[float64](len(Fixed64Values))
|
||||
for i in 0 ..< len(values):
|
||||
values[i] = cast[float64](Fixed64Values[i])
|
||||
pb1.writePacked(1, values)
|
||||
pb1.writePacked(2, values[0 .. 1])
|
||||
pb1.finish()
|
||||
let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", ""))
|
||||
check:
|
||||
toHex(pb1.buffer) == expect
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float64]
|
||||
var fieldarr2: seq[float64]
|
||||
var fieldarr3: seq[float64]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 8
|
||||
cast[uint64](fieldarr1[0]) == Fixed64Values[0]
|
||||
cast[uint64](fieldarr1[1]) == Fixed64Values[1]
|
||||
cast[uint64](fieldarr1[2]) == Fixed64Values[2]
|
||||
cast[uint64](fieldarr1[3]) == Fixed64Values[3]
|
||||
cast[uint64](fieldarr1[4]) == Fixed64Values[4]
|
||||
cast[uint64](fieldarr1[5]) == Fixed64Values[5]
|
||||
cast[uint64](fieldarr1[6]) == Fixed64Values[6]
|
||||
cast[uint64](fieldarr1[7]) == Fixed64Values[7]
|
||||
cast[uint64](fieldarr2[0]) == Fixed64Values[0]
|
||||
cast[uint64](fieldarr2[1]) == Fixed64Values[1]
|
||||
|
||||
test "[length] edge values test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
let data1 = getLengthEncodedValue(LengthValues[i])
|
||||
let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i]))
|
||||
check:
|
||||
toHex(data1) == LengthVectors[i]
|
||||
toHex(data2) == LengthVectors[i]
|
||||
check:
|
||||
getLengthDecodedValue(data1) == LengthValues[i]
|
||||
getLengthDecodedValue(data2) == LengthValues[i]
|
||||
|
||||
test "[length] mixing many values with same field number test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var pb1 = initProtoBuffer()
|
||||
var pb2 = initProtoBuffer()
|
||||
for k in 0 ..< len(LengthValues):
|
||||
let index = (i + k + 1) mod len(LengthValues)
|
||||
pb1.write(1, LengthValues[index])
|
||||
pb2.write(1, cast[seq[byte]](LengthValues[index]))
|
||||
pb1.finish()
|
||||
pb2.finish()
|
||||
check getLengthDecodedValue(pb1.buffer) == LengthValues[i]
|
||||
check getLengthDecodedValue(pb2.buffer) == LengthValues[i]
|
||||
|
||||
test "[length] incorrect values test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] non-existent field test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value, valueLen) == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] corrupted header test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
for k in 0 ..< 3:
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] empty buffer test":
|
||||
var value = newSeq[byte](len(LengthValues[0]))
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] buffer overflow test":
|
||||
for i in 1 ..< len(LengthValues):
|
||||
let data = getLengthEncodedValue(LengthValues[i])
|
||||
|
||||
var value = newString(len(LengthValues[i]) - 1)
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
valueLen == len(LengthValues[i])
|
||||
isFullZero(value) == true
|
||||
|
||||
test "[length] mix of buffer overflow and normal fields test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, "TEST10")
|
||||
pb1.write(1, "TEST20")
|
||||
pb1.write(1, "TEST")
|
||||
pb1.write(1, "TEST30")
|
||||
pb1.write(1, "SOME")
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(4)
|
||||
var valueLen = 0
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == true
|
||||
value == "SOME"
|
||||
|
||||
test "[length] too big message test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var bigString = newString(MaxMessageSize + 1)
|
||||
|
||||
for i in 0 ..< len(bigString):
|
||||
bigString[i] = 'A'
|
||||
pb1.write(1, bigString)
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(MaxMessageSize + 1)
|
||||
var valueLen = 0
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, "TEST1")
|
||||
pb1.write(1, "TEST2")
|
||||
pb1.write(2, "TEST5")
|
||||
pb1.write(1, "TEST3")
|
||||
pb1.write(1, "TEST4")
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[seq[byte]]
|
||||
var fieldarr2: seq[seq[byte]]
|
||||
var fieldarr3: seq[seq[byte]]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[string](fieldarr1[0]) == "TEST1"
|
||||
cast[string](fieldarr1[1]) == "TEST2"
|
||||
cast[string](fieldarr1[2]) == "TEST3"
|
||||
cast[string](fieldarr1[3]) == "TEST4"
|
||||
cast[string](fieldarr2[0]) == "TEST5"
|
||||
|
||||
test "Different value types in one message with same field number test":
|
||||
proc getEncodedValue(): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, VarintValues[1])
|
||||
pb.write(2, cast[float32](Fixed32Values[1]))
|
||||
pb.write(3, cast[float64](Fixed64Values[1]))
|
||||
pb.write(4, LengthValues[1])
|
||||
|
||||
pb.write(1, VarintValues[2])
|
||||
pb.write(2, cast[float32](Fixed32Values[2]))
|
||||
pb.write(3, cast[float64](Fixed64Values[2]))
|
||||
pb.write(4, LengthValues[2])
|
||||
|
||||
pb.write(1, cast[float32](Fixed32Values[3]))
|
||||
pb.write(2, cast[float64](Fixed64Values[3]))
|
||||
pb.write(3, LengthValues[3])
|
||||
pb.write(4, VarintValues[3])
|
||||
|
||||
pb.write(1, cast[float64](Fixed64Values[4]))
|
||||
pb.write(2, LengthValues[4])
|
||||
pb.write(3, VarintValues[4])
|
||||
pb.write(4, cast[float32](Fixed32Values[4]))
|
||||
|
||||
pb.write(1, VarintValues[1])
|
||||
pb.write(2, cast[float32](Fixed32Values[1]))
|
||||
pb.write(3, cast[float64](Fixed64Values[1]))
|
||||
pb.write(4, LengthValues[1])
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
let msg = getEncodedValue()
|
||||
let pb = initProtoBuffer(msg)
|
||||
var varintValue: uint64
|
||||
var fixed32Value: float32
|
||||
var fixed64Value: float64
|
||||
var lengthValue = newString(10)
|
||||
var lengthSize: int
|
||||
|
||||
check:
|
||||
pb.getField(1, varintValue) == true
|
||||
pb.getField(2, fixed32Value) == true
|
||||
pb.getField(3, fixed64Value) == true
|
||||
pb.getField(4, lengthValue, lengthSize) == true
|
||||
|
||||
lengthValue.setLen(lengthSize)
|
||||
|
||||
check:
|
||||
varintValue == VarintValues[1]
|
||||
cast[uint32](fixed32Value) == Fixed32Values[1]
|
||||
cast[uint64](fixed64Value) == Fixed64Values[1]
|
||||
lengthValue == LengthValues[1]
|
|
@ -1,4 +1,5 @@
|
|||
import testvarint,
|
||||
testminprotobuf,
|
||||
teststreamseq
|
||||
|
||||
import testrsa,
|
||||
|
|
Loading…
Reference in New Issue