Minprotobuf refactoring 2 (#269)
* Protobuf refactoring stage II. * Remove NoError. * Change trace level for invalid message.
This commit is contained in:
parent
9eb5828a42
commit
b832668768
|
@ -31,11 +31,6 @@ type
|
|||
ECDSA,
|
||||
NoSupport
|
||||
|
||||
CipherScheme* = enum
|
||||
Aes128 = 0,
|
||||
Aes256,
|
||||
Blowfish
|
||||
|
||||
DigestSheme* = enum
|
||||
Sha256,
|
||||
Sha512
|
||||
|
@ -283,7 +278,9 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
|
|||
var buffer: seq[byte]
|
||||
if len(data) > 0:
|
||||
var pb = initProtoBuffer(@data)
|
||||
if pb.getField(1, id) and pb.getField(2, buffer):
|
||||
let r1 = pb.getField(1, id)
|
||||
let r2 = pb.getField(2, buffer)
|
||||
if r1.isOk() and r1.get() and r2.isOk() and r2.get():
|
||||
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
|
||||
var scheme = cast[PKScheme](cast[int8](id))
|
||||
when key is PrivateKey:
|
||||
|
@ -743,9 +740,15 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
pb.getField(1, nonce) and pb.getField(2, pubkey) and
|
||||
pb.getField(3, exchanges) and pb.getField(4, ciphers) and
|
||||
pb.getField(5, hashes)
|
||||
let r1 = pb.getField(1, nonce)
|
||||
let r2 = pb.getField(2, pubkey)
|
||||
let r3 = pb.getField(3, exchanges)
|
||||
let r4 = pb.getField(4, ciphers)
|
||||
let r5 = pb.getField(5, hashes)
|
||||
|
||||
r1.isOk() and r1.get() and r2.isOk() and r2.get() and
|
||||
r3.isOk() and r3.get() and r4.isOk() and r4.get() and
|
||||
r5.isOk() and r5.get()
|
||||
|
||||
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
|
||||
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
|
||||
|
@ -763,7 +766,9 @@ proc decodeExchange*(message: seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
pb.getField(1, pubkey) and pb.getField(2, signature)
|
||||
let r1 = pb.getField(1, pubkey)
|
||||
let r2 = pb.getField(2, signature)
|
||||
r1.isOk() and r1.get() and r2.isOk() and r2.get()
|
||||
|
||||
## Serialization/Deserialization helpers
|
||||
|
||||
|
@ -825,28 +830,37 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
|
|||
value = sig
|
||||
|
||||
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
|
||||
value: var T): bool =
|
||||
value: var T): ProtoResult[bool] =
|
||||
## Deserialize public/private key from protobuf's message ``pb`` using field
|
||||
## index ``field``.
|
||||
##
|
||||
## On success deserialized key will be stored in ``value``.
|
||||
var buffer: seq[byte]
|
||||
var key: T
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if key.init(buffer):
|
||||
value = key
|
||||
true
|
||||
let res = ? pb.getField(field, buffer)
|
||||
if not(res):
|
||||
ok(false)
|
||||
else:
|
||||
false
|
||||
if key.init(buffer):
|
||||
value = key
|
||||
ok(true)
|
||||
else:
|
||||
err(ProtoError.IncorrectBlob)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
|
||||
proc getField*(pb: ProtoBuffer, field: int,
|
||||
value: var Signature): ProtoResult[bool] =
|
||||
## Deserialize signature from protobuf's message ``pb`` using field index
|
||||
## ``field``.
|
||||
##
|
||||
## On success deserialized signature will be stored in ``value``.
|
||||
var buffer: seq[byte]
|
||||
var sig: Signature
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if sig.init(buffer):
|
||||
value = sig
|
||||
true
|
||||
let res = ? pb.getField(field, buffer)
|
||||
if not(res):
|
||||
ok(false)
|
||||
else:
|
||||
false
|
||||
if sig.init(buffer):
|
||||
value = sig
|
||||
ok(true)
|
||||
else:
|
||||
err(ProtoError.IncorrectBlob)
|
||||
|
|
|
@ -1025,31 +1025,34 @@ proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
|
|||
write(pb, field, value.data.buffer)
|
||||
|
||||
proc getField*(pb: var ProtoBuffer, field: int,
|
||||
value: var MultiAddress): bool {.inline.} =
|
||||
value: var MultiAddress): ProtoResult[bool] {.
|
||||
inline.} =
|
||||
var buffer: seq[byte]
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
let ma = MultiAddress.init(buffer)
|
||||
if ma.isOk():
|
||||
value = ma.get()
|
||||
true
|
||||
let res = ? pb.getField(field, buffer)
|
||||
if not(res):
|
||||
ok(false)
|
||||
else:
|
||||
false
|
||||
let ma = MultiAddress.init(buffer)
|
||||
if ma.isOk():
|
||||
value = ma.get()
|
||||
ok(true)
|
||||
else:
|
||||
err(ProtoError.IncorrectBlob)
|
||||
|
||||
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
|
||||
value: var seq[MultiAddress]): bool {.inline.} =
|
||||
value: var seq[MultiAddress]): ProtoResult[bool] {.
|
||||
inline.} =
|
||||
var items: seq[seq[byte]]
|
||||
value.setLen(0)
|
||||
if not(getRepeatedField(pb, field, items)):
|
||||
return false
|
||||
if len(items) == 0:
|
||||
return true
|
||||
for item in items:
|
||||
let ma = MultiAddress.init(item)
|
||||
if ma.isOk():
|
||||
value.add(ma.get())
|
||||
else:
|
||||
value.setLen(0)
|
||||
return false
|
||||
let res = ? pb.getRepeatedField(field, items)
|
||||
if not(res):
|
||||
ok(false)
|
||||
else:
|
||||
for item in items:
|
||||
let ma = MultiAddress.init(item)
|
||||
if ma.isOk():
|
||||
value.add(ma.get())
|
||||
else:
|
||||
value.setLen(0)
|
||||
return err(ProtoError.IncorrectBlob)
|
||||
ok(true)
|
||||
|
|
|
@ -219,16 +219,17 @@ proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
|
|||
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
|
||||
write(pb, field, pid.data)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
|
||||
proc getField*(pb: ProtoBuffer, field: int,
|
||||
pid: var PeerID): ProtoResult[bool] {.inline.} =
|
||||
## Read ``PeerID`` from ProtoBuf's message and validate it
|
||||
var buffer: seq[byte]
|
||||
var peerId: PeerID
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if peerId.init(buffer):
|
||||
pid = peerId
|
||||
true
|
||||
let res = ? pb.getField(field, buffer)
|
||||
if not(res):
|
||||
ok(false)
|
||||
else:
|
||||
false
|
||||
var peerId: PeerID
|
||||
if peerId.init(buffer):
|
||||
pid = peerId
|
||||
ok(true)
|
||||
else:
|
||||
err(ProtoError.IncorrectBlob)
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import ../varint, stew/endians2
|
||||
import ../varint, stew/[endians2, results]
|
||||
export results
|
||||
|
||||
const
|
||||
MaxMessageSize* = 1'u shl 22
|
||||
|
@ -51,12 +52,15 @@ type
|
|||
of StartGroup, EndGroup:
|
||||
discard
|
||||
|
||||
ProtoResult {.pure.} = enum
|
||||
VarintDecodeError,
|
||||
MessageIncompleteError,
|
||||
BufferOverflowError,
|
||||
MessageSizeTooBigError,
|
||||
NoError
|
||||
ProtoError* {.pure.} = enum
|
||||
VarintDecode,
|
||||
MessageIncomplete,
|
||||
BufferOverflow,
|
||||
MessageTooBig,
|
||||
BadWireType,
|
||||
IncorrectBlob
|
||||
|
||||
ProtoResult*[T] = Result[T, ProtoError]
|
||||
|
||||
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
|
||||
hint | hint32 | hint64 | float32 | float64
|
||||
|
@ -361,7 +365,8 @@ proc finish*(pb: var ProtoBuffer) =
|
|||
else:
|
||||
pb.offset = 0
|
||||
|
||||
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
|
||||
proc getHeader(data: var ProtoBuffer,
|
||||
header: var ProtoHeader): ProtoResult[void] =
|
||||
var length = 0
|
||||
var hdr = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
|
||||
|
@ -370,34 +375,34 @@ proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
|
|||
if wire in SupportedWireTypes:
|
||||
data.offset += length
|
||||
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
|
||||
true
|
||||
ok()
|
||||
else:
|
||||
false
|
||||
err(ProtoError.BadWireType)
|
||||
else:
|
||||
false
|
||||
err(ProtoError.VarintDecode)
|
||||
|
||||
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
|
||||
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
|
||||
case header.wire
|
||||
of ProtoFieldKind.Varint:
|
||||
var length = 0
|
||||
var value = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
true
|
||||
ok()
|
||||
else:
|
||||
false
|
||||
err(ProtoError.VarintDecode)
|
||||
of ProtoFieldKind.Fixed32:
|
||||
if data.isEnough(sizeof(uint32)):
|
||||
data.offset += sizeof(uint32)
|
||||
true
|
||||
ok()
|
||||
else:
|
||||
false
|
||||
err(ProtoError.VarintDecode)
|
||||
of ProtoFieldKind.Fixed64:
|
||||
if data.isEnough(sizeof(uint64)):
|
||||
data.offset += sizeof(uint64)
|
||||
true
|
||||
ok()
|
||||
else:
|
||||
false
|
||||
err(ProtoError.VarintDecode)
|
||||
of ProtoFieldKind.Length:
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
|
@ -406,19 +411,19 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
|
|||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
data.offset += int(bsize)
|
||||
true
|
||||
ok()
|
||||
else:
|
||||
false
|
||||
err(ProtoError.MessageIncomplete)
|
||||
else:
|
||||
false
|
||||
err(ProtoError.MessageTooBig)
|
||||
else:
|
||||
false
|
||||
err(ProtoError.VarintDecode)
|
||||
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
|
||||
false
|
||||
err(ProtoError.BadWireType)
|
||||
|
||||
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
|
||||
header: ProtoHeader,
|
||||
outval: var T): ProtoResult =
|
||||
outval: var T): ProtoResult[void] =
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
var length = 0
|
||||
|
@ -426,9 +431,9 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer,
|
|||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
err(ProtoError.VarintDecode)
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
|
@ -437,29 +442,29 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer,
|
|||
if getSVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
err(ProtoError.VarintDecode)
|
||||
elif T is float32:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed32)
|
||||
if data.isEnough(sizeof(float32)):
|
||||
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
|
||||
data.offset += sizeof(float32)
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
err(ProtoError.MessageIncomplete)
|
||||
elif T is float64:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed64)
|
||||
if data.isEnough(sizeof(float64)):
|
||||
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
|
||||
data.offset += sizeof(float64)
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
err(ProtoError.MessageIncomplete)
|
||||
|
||||
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var openarray[T],
|
||||
outLength: var int): ProtoResult =
|
||||
outLength: var int): ProtoResult[void] =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
|
@ -474,20 +479,20 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
|
|||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
# Buffer overflow should not be critical failure
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.BufferOverflowError
|
||||
err(ProtoError.BufferOverflow)
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
err(ProtoError.MessageIncomplete)
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
err(ProtoError.MessageTooBig)
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
err(ProtoError.VarintDecode)
|
||||
|
||||
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var T): ProtoResult =
|
||||
outBytes: var T): ProtoResult[void] =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
|
@ -501,27 +506,24 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
|
|||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
ok()
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
err(ProtoError.MessageIncomplete)
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
err(ProtoError.MessageTooBig)
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
err(ProtoError.VarintDecode)
|
||||
|
||||
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
output: var T): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var value: T
|
||||
var current: T
|
||||
var res = false
|
||||
var pb = data
|
||||
output = T(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output = T(0)
|
||||
return false
|
||||
? pb.getHeader(header)
|
||||
let wireCheck =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
|
@ -533,28 +535,29 @@ proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
|||
header.wire == ProtoFieldKind.Fixed64
|
||||
if header.index == uint64(field):
|
||||
if wireCheck:
|
||||
let r = getValue(pb, header, value)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
var value: T
|
||||
let vres = pb.getValue(header, value)
|
||||
if vres.isOk():
|
||||
res = true
|
||||
output = value
|
||||
current = value
|
||||
else:
|
||||
return false
|
||||
return err(vres.error)
|
||||
else:
|
||||
# We are ignoring wire types different from what we expect, because it
|
||||
# is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
? pb.skipValue(header)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
res
|
||||
? pb.skipValue(header)
|
||||
|
||||
if res:
|
||||
output = current
|
||||
ok(true)
|
||||
else:
|
||||
ok(false)
|
||||
|
||||
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
|
||||
output: var openarray[T],
|
||||
outlen: var int): bool =
|
||||
outlen: var int): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
var res = false
|
||||
|
@ -563,182 +566,191 @@ proc getField*[T: byte|char](data: ProtoBuffer, field: int,
|
|||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
let hres = pb.getHeader(header)
|
||||
if hres.isErr():
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
|
||||
return err(hres.error)
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output, outlen)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
let vres = pb.getValue(header, output, outlen)
|
||||
if vres.isOk():
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
else:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
return false
|
||||
if vres.error != ProtoError.BufferOverflow:
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return err(vres.error)
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
return err(sres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
return err(sres.error)
|
||||
|
||||
res
|
||||
if res:
|
||||
ok(true)
|
||||
else:
|
||||
ok(false)
|
||||
|
||||
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
output: var T): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var res = false
|
||||
var pb = data
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
let hres = pb.getHeader(header)
|
||||
if hres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
return err(hres.error)
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
let vres = pb.getValue(header, output)
|
||||
if vres.isOk():
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(vres.error)
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
res
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
|
||||
inline.} =
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(field, buffer):
|
||||
output = initProtoBuffer(buffer)
|
||||
true
|
||||
return err(sres.error)
|
||||
if res:
|
||||
ok(true)
|
||||
else:
|
||||
false
|
||||
ok(false)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int,
|
||||
output: var ProtoBuffer): ProtoResult[bool] {.inline.} =
|
||||
var buffer: seq[byte]
|
||||
let res = pb.getField(field, buffer)
|
||||
if res.isOk():
|
||||
if res.get():
|
||||
output = initProtoBuffer(buffer)
|
||||
ok(true)
|
||||
else:
|
||||
ok(false)
|
||||
else:
|
||||
err(res.error)
|
||||
|
||||
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
output: var seq[T]): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
let hres = pb.getHeader(header)
|
||||
if hres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
return err(hres.error)
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
let vres = pb.getValue(header, item)
|
||||
if vres.isOk():
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(vres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = pb.skipValue(header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
ok(true)
|
||||
else:
|
||||
false
|
||||
ok(false)
|
||||
|
||||
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
|
||||
field: int,
|
||||
output: var seq[T]): bool =
|
||||
proc getRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
let hres = pb.getHeader(header)
|
||||
if hres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(hres.error)
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
|
||||
ProtoFieldKind.Fixed64}:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
let vres = getValue(pb, header, item)
|
||||
if vres.isOk():
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(vres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = skipValue(pb, header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = skipValue(pb, header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
ok(true)
|
||||
else:
|
||||
false
|
||||
ok(false)
|
||||
|
||||
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
output: var seq[T]): ProtoResult[bool] =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
let hres = pb.getHeader(header)
|
||||
if hres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(hres.error)
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var arritem: seq[byte]
|
||||
let rarr = getValue(pb, header, arritem)
|
||||
case rarr
|
||||
of ProtoResult.NoError:
|
||||
let ares = getValue(pb, header, arritem)
|
||||
if ares.isOk():
|
||||
var pbarr = initProtoBuffer(arritem)
|
||||
let itemHeader =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
|
@ -751,29 +763,30 @@ proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
|||
ProtoHeader(wire: ProtoFieldKind.Fixed64)
|
||||
while not(pbarr.isEmpty()):
|
||||
var item: T
|
||||
let res = getValue(pbarr, itemHeader, item)
|
||||
case res
|
||||
of ProtoResult.NoError:
|
||||
let vres = getValue(pbarr, itemHeader, item)
|
||||
if vres.isOk():
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(vres.error)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(ares.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = skipValue(pb, header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
let sres = skipValue(pb, header)
|
||||
if sres.isErr():
|
||||
output.setLen(0)
|
||||
return false
|
||||
return err(sres.error)
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
ok(true)
|
||||
else:
|
||||
false
|
||||
ok(false)
|
||||
|
||||
proc getVarintValue*(data: var ProtoBuffer, field: int,
|
||||
value: var SomeVarint): int {.deprecated.} =
|
||||
|
|
|
@ -46,52 +46,56 @@ type
|
|||
|
||||
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
|
||||
result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
|
||||
|
||||
for ma in peerInfo.addrs:
|
||||
result.write(2, ma.data.buffer)
|
||||
|
||||
for proto in peerInfo.protocols:
|
||||
result.write(3, proto)
|
||||
|
||||
result.write(4, observedAddr.data.buffer)
|
||||
|
||||
let protoVersion = ProtoVersion
|
||||
result.write(5, protoVersion)
|
||||
|
||||
let agentVersion = AgentVersion
|
||||
result.write(6, agentVersion)
|
||||
result.finish()
|
||||
|
||||
proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
|
||||
proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] =
|
||||
var
|
||||
iinfo: IdentifyInfo
|
||||
pubKey: PublicKey
|
||||
oaddr: MultiAddress
|
||||
protoVersion: string
|
||||
agentVersion: string
|
||||
|
||||
var pb = initProtoBuffer(buf)
|
||||
|
||||
var pubKey: PublicKey
|
||||
if pb.getField(1, pubKey):
|
||||
trace "read public key from message", pubKey = ($pubKey).shortLog
|
||||
result.pubKey = some(pubKey)
|
||||
let r1 = pb.getField(1, pubKey)
|
||||
let r2 = pb.getRepeatedField(2, iinfo.addrs)
|
||||
let r3 = pb.getRepeatedField(3, iinfo.protos)
|
||||
let r4 = pb.getField(4, oaddr)
|
||||
let r5 = pb.getField(5, protoVersion)
|
||||
let r6 = pb.getField(6, agentVersion)
|
||||
|
||||
if pb.getRepeatedField(2, result.addrs):
|
||||
trace "read addresses from message", addresses = result.addrs
|
||||
let res = r1.isOk() and r2.isOk() and r3.isOk() and
|
||||
r4.isOk() and r5.isOk() and r6.isOk()
|
||||
|
||||
if pb.getRepeatedField(3, result.protos):
|
||||
trace "read protos from message", protocols = result.protos
|
||||
|
||||
var observableAddr: MultiAddress
|
||||
if pb.getField(4, observableAddr):
|
||||
trace "read observableAddr from message", address = observableAddr
|
||||
result.observedAddr = some(observableAddr)
|
||||
|
||||
var protoVersion = ""
|
||||
if pb.getField(5, protoVersion):
|
||||
trace "read protoVersion from message", protoVersion = protoVersion
|
||||
result.protoVersion = some(protoVersion)
|
||||
|
||||
var agentVersion = ""
|
||||
if pb.getField(6, agentVersion):
|
||||
trace "read agentVersion from message", agentVersion = agentVersion
|
||||
result.agentVersion = some(agentVersion)
|
||||
if res:
|
||||
if r1.get():
|
||||
iinfo.pubKey = some(pubKey)
|
||||
if r4.get():
|
||||
iinfo.observedAddr = some(oaddr)
|
||||
if r5.get():
|
||||
iinfo.protoVersion = some(protoVersion)
|
||||
if r6.get():
|
||||
iinfo.agentVersion = some(agentVersion)
|
||||
trace "decodeMsg: decoded message", pubkey = ($pubKey).shortLog,
|
||||
addresses = $iinfo.addrs, protocols = $iinfo.protos,
|
||||
observable_address = $iinfo.observedAddr,
|
||||
proto_version = $iinfo.protoVersion,
|
||||
agent_version = $iinfo.agentVersion
|
||||
some(iinfo)
|
||||
else:
|
||||
trace "decodeMsg: failed to decode received message"
|
||||
none[IdentifyInfo]()
|
||||
|
||||
proc newIdentify*(peerInfo: PeerInfo): Identify =
|
||||
new result
|
||||
|
@ -122,11 +126,13 @@ proc identify*(p: Identify,
|
|||
trace "initiating identify", peer = $conn
|
||||
var message = await conn.readLp(64*1024)
|
||||
if len(message) == 0:
|
||||
trace "identify: Invalid or empty message received!"
|
||||
raise newException(IdentityInvalidMsgError,
|
||||
"Invalid or empty message received!")
|
||||
trace "identify: Empty message received!"
|
||||
raise newException(IdentityInvalidMsgError, "Empty message received!")
|
||||
|
||||
result = decodeMsg(message)
|
||||
let infoOpt = decodeMsg(message)
|
||||
if infoOpt.isNone():
|
||||
raise newException(IdentityInvalidMsgError, "Incorrect message received!")
|
||||
result = infoOpt.get()
|
||||
|
||||
if not isNil(remotePeerInfo) and result.pubKey.isSome:
|
||||
let peer = PeerID.init(result.pubKey.get())
|
||||
|
|
|
@ -43,7 +43,7 @@ type
|
|||
|
||||
RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.}
|
||||
|
||||
func hash*(p: PubSubPeer): Hash =
|
||||
func hash*(p: PubSubPeer): Hash =
|
||||
# int is either 32/64, so intptr basically, pubsubpeer is a ref
|
||||
cast[pointer](p).hash
|
||||
|
||||
|
@ -114,7 +114,13 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
|
|||
trace "message already received, skipping", peer = p.id
|
||||
continue
|
||||
|
||||
var msg = decodeRpcMsg(data)
|
||||
var rmsg = decodeRpcMsg(data)
|
||||
if rmsg.isErr():
|
||||
notice "failed to decode msg from peer", peer = p.id
|
||||
break
|
||||
|
||||
var msg = rmsg.get()
|
||||
|
||||
trace "decoded msg from peer", peer = p.id, msg = msg.shortLog
|
||||
# trigger hooks
|
||||
p.recvObservers(msg)
|
||||
|
@ -149,11 +155,11 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
|
|||
p.sendObservers(mm)
|
||||
|
||||
let encoded = encodeRpcMsg(mm)
|
||||
if encoded.buffer.len <= 0:
|
||||
if encoded.len <= 0:
|
||||
trace "empty message, skipping", peer = p.id
|
||||
return
|
||||
|
||||
let digest = $(sha256.digest(encoded.buffer))
|
||||
let digest = $(sha256.digest(encoded))
|
||||
if digest in p.sentRpcCache:
|
||||
trace "message already sent to peer, skipping", peer = p.id
|
||||
libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id])
|
||||
|
@ -164,8 +170,8 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
|
|||
encoded = digest
|
||||
if p.connected: # this can happen if the remote disconnected
|
||||
trace "sending encoded msgs to peer", peer = p.id,
|
||||
encoded = encoded.buffer.shortLog
|
||||
await p.sendConn.writeLp(encoded.buffer)
|
||||
encoded = encoded.shortLog
|
||||
await p.sendConn.writeLp(encoded)
|
||||
p.sentRpcCache.put(digest)
|
||||
|
||||
for m in msgs:
|
||||
|
|
|
@ -80,163 +80,151 @@ proc encodeMessage*(msg: Message): seq[byte] =
|
|||
proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
|
||||
pb.write(field, encodeMessage(msg))
|
||||
|
||||
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
|
||||
proc decodeGraft*(pb: ProtoBuffer): ProtoResult[ControlGraft] {.
|
||||
inline.} =
|
||||
trace "decodeGraft: decoding message"
|
||||
var control = ControlGraft()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeGraft: read topicId", topic_id = topicId
|
||||
if ? pb.getField(1, control.topicId):
|
||||
trace "decodeGraft: read topicId", topic_id = control.topicId
|
||||
else:
|
||||
trace "decodeGraft: topicId is missing"
|
||||
control
|
||||
ok(control)
|
||||
|
||||
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
|
||||
proc decodePrune*(pb: ProtoBuffer): ProtoResult[ControlPrune] {.
|
||||
inline.} =
|
||||
trace "decodePrune: decoding message"
|
||||
var control = ControlPrune()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodePrune: read topicId", topic_id = topicId
|
||||
if ? pb.getField(1, control.topicId):
|
||||
trace "decodePrune: read topicId", topic_id = control.topicId
|
||||
else:
|
||||
trace "decodePrune: topicId is missing"
|
||||
control
|
||||
ok(control)
|
||||
|
||||
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
|
||||
proc decodeIHave*(pb: ProtoBuffer): ProtoResult[ControlIHave] {.
|
||||
inline.} =
|
||||
trace "decodeIHave: decoding message"
|
||||
var control = ControlIHave()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeIHave: read topicId", topic_id = topicId
|
||||
if ? pb.getField(1, control.topicId):
|
||||
trace "decodeIHave: read topicId", topic_id = control.topicId
|
||||
else:
|
||||
trace "decodeIHave: topicId is missing"
|
||||
if pb.getRepeatedField(2, control.messageIDs):
|
||||
if ? pb.getRepeatedField(2, control.messageIDs):
|
||||
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIHave: no messageIDs"
|
||||
control
|
||||
ok(control)
|
||||
|
||||
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
|
||||
proc decodeIWant*(pb: ProtoBuffer): ProtoResult[ControlIWant] {.inline.} =
|
||||
trace "decodeIWant: decoding message"
|
||||
var control = ControlIWant()
|
||||
if pb.getRepeatedField(1, control.messageIDs):
|
||||
if ? pb.getRepeatedField(1, control.messageIDs):
|
||||
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIWant: no messageIDs"
|
||||
ok(control)
|
||||
|
||||
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
|
||||
proc decodeControl*(pb: ProtoBuffer): ProtoResult[Option[ControlMessage]] {.
|
||||
inline.} =
|
||||
trace "decodeControl: decoding message"
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(3, buffer):
|
||||
if ? pb.getField(3, buffer):
|
||||
var control: ControlMessage
|
||||
var cpb = initProtoBuffer(buffer)
|
||||
var ihavepbs: seq[seq[byte]]
|
||||
var iwantpbs: seq[seq[byte]]
|
||||
var graftpbs: seq[seq[byte]]
|
||||
var prunepbs: seq[seq[byte]]
|
||||
|
||||
discard cpb.getRepeatedField(1, ihavepbs)
|
||||
discard cpb.getRepeatedField(2, iwantpbs)
|
||||
discard cpb.getRepeatedField(3, graftpbs)
|
||||
discard cpb.getRepeatedField(4, prunepbs)
|
||||
|
||||
for item in ihavepbs:
|
||||
control.ihave.add(decodeIHave(initProtoBuffer(item)))
|
||||
for item in iwantpbs:
|
||||
control.iwant.add(decodeIWant(initProtoBuffer(item)))
|
||||
for item in graftpbs:
|
||||
control.graft.add(decodeGraft(initProtoBuffer(item)))
|
||||
for item in prunepbs:
|
||||
control.prune.add(decodePrune(initProtoBuffer(item)))
|
||||
|
||||
trace "decodeControl: "
|
||||
some(control)
|
||||
if ? cpb.getRepeatedField(1, ihavepbs):
|
||||
for item in ihavepbs:
|
||||
control.ihave.add(? decodeIHave(initProtoBuffer(item)))
|
||||
if ? cpb.getRepeatedField(2, iwantpbs):
|
||||
for item in iwantpbs:
|
||||
control.iwant.add(? decodeIWant(initProtoBuffer(item)))
|
||||
if ? cpb.getRepeatedField(3, graftpbs):
|
||||
for item in graftpbs:
|
||||
control.graft.add(? decodeGraft(initProtoBuffer(item)))
|
||||
if ? cpb.getRepeatedField(4, prunepbs):
|
||||
for item in prunepbs:
|
||||
control.prune.add(? decodePrune(initProtoBuffer(item)))
|
||||
trace "decodeControl: message statistics", graft_count = len(control.graft),
|
||||
prune_count = len(control.prune),
|
||||
ihave_count = len(control.ihave),
|
||||
iwant_count = len(control.iwant)
|
||||
ok(some(control))
|
||||
else:
|
||||
none[ControlMessage]()
|
||||
ok(none[ControlMessage]())
|
||||
|
||||
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
|
||||
proc decodeSubscription*(pb: ProtoBuffer): ProtoResult[SubOpts] {.inline.} =
|
||||
trace "decodeSubscription: decoding message"
|
||||
var subflag: uint64
|
||||
var sub = SubOpts()
|
||||
if pb.getField(1, subflag):
|
||||
if ? pb.getField(1, subflag):
|
||||
sub.subscribe = bool(subflag)
|
||||
trace "decodeSubscription: read subscribe", subscribe = subflag
|
||||
else:
|
||||
trace "decodeSubscription: subscribe is missing"
|
||||
if pb.getField(2, sub.topic):
|
||||
if ? pb.getField(2, sub.topic):
|
||||
trace "decodeSubscription: read topic", topic = sub.topic
|
||||
else:
|
||||
trace "decodeSubscription: topic is missing"
|
||||
ok(sub)
|
||||
|
||||
sub
|
||||
|
||||
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
|
||||
proc decodeSubscriptions*(pb: ProtoBuffer): ProtoResult[seq[SubOpts]] {.
|
||||
inline.} =
|
||||
trace "decodeSubscriptions: decoding message"
|
||||
var subpbs: seq[seq[byte]]
|
||||
var subs: seq[SubOpts]
|
||||
if pb.getRepeatedField(1, subpbs):
|
||||
let res = ? pb.getRepeatedField(1, subpbs)
|
||||
if res:
|
||||
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
|
||||
for item in subpbs:
|
||||
let sub = decodeSubscription(initProtoBuffer(item))
|
||||
subs.add(sub)
|
||||
subs.add(? decodeSubscription(initProtoBuffer(item)))
|
||||
if len(subs) == 0:
|
||||
trace "decodeSubscription: no subscriptions found"
|
||||
ok(subs)
|
||||
|
||||
if len(subs) == 0:
|
||||
trace "decodeSubscription: no subscriptions found"
|
||||
|
||||
subs
|
||||
|
||||
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
|
||||
proc decodeMessage*(pb: ProtoBuffer): ProtoResult[Message] {.inline.} =
|
||||
trace "decodeMessage: decoding message"
|
||||
var msg: Message
|
||||
if pb.getField(1, msg.fromPeer):
|
||||
if ? pb.getField(1, msg.fromPeer):
|
||||
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
|
||||
else:
|
||||
trace "decodeMessage: fromPeer is missing"
|
||||
|
||||
if pb.getField(2, msg.data):
|
||||
if ? pb.getField(2, msg.data):
|
||||
trace "decodeMessage: read data", data = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: data is missing"
|
||||
|
||||
if pb.getField(3, msg.seqno):
|
||||
if ? pb.getField(3, msg.seqno):
|
||||
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: seqno is missing"
|
||||
|
||||
if pb.getRepeatedField(4, msg.topicIDs):
|
||||
if ? pb.getRepeatedField(4, msg.topicIDs):
|
||||
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
|
||||
else:
|
||||
trace "decodeMessage: topics are missing"
|
||||
|
||||
if pb.getField(5, msg.signature):
|
||||
if ? pb.getField(5, msg.signature):
|
||||
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: signature is missing"
|
||||
|
||||
if pb.getField(6, msg.key):
|
||||
if ? pb.getField(6, msg.key):
|
||||
trace "decodeMessage: read public key", key = msg.key.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: public key is missing"
|
||||
ok(msg)
|
||||
|
||||
msg
|
||||
|
||||
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
|
||||
proc decodeMessages*(pb: ProtoBuffer): ProtoResult[seq[Message]] {.inline.} =
|
||||
trace "decodeMessages: decoding message"
|
||||
var msgpbs: seq[seq[byte]]
|
||||
var msgs: seq[Message]
|
||||
if pb.getRepeatedField(2, msgpbs):
|
||||
if ? pb.getRepeatedField(2, msgpbs):
|
||||
trace "decodeMessages: read messages", count = len(msgpbs)
|
||||
for item in msgpbs:
|
||||
let msg = decodeMessage(initProtoBuffer(item))
|
||||
msgs.add(msg)
|
||||
|
||||
if len(msgs) == 0:
|
||||
msgs.add(? decodeMessage(initProtoBuffer(item)))
|
||||
else:
|
||||
trace "decodeMessages: no messages found"
|
||||
ok(msgs)
|
||||
|
||||
msgs
|
||||
|
||||
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
|
||||
proc encodeRpcMsg*(msg: RPCMsg): seq[byte] =
|
||||
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer()
|
||||
for item in msg.subscriptions:
|
||||
|
@ -247,14 +235,13 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
|
|||
pb.write(3, msg.control.get())
|
||||
if len(pb.buffer) > 0:
|
||||
pb.finish()
|
||||
result = pb
|
||||
pb.buffer
|
||||
|
||||
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
|
||||
proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} =
|
||||
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer(msg)
|
||||
var rpcMsg: RPCMsg
|
||||
rpcMsg.messages = pb.decodeMessages()
|
||||
rpcMsg.subscriptions = pb.decodeSubscriptions()
|
||||
rpcMsg.control = pb.decodeControl()
|
||||
|
||||
rpcMsg
|
||||
rpcMsg.messages = ? pb.decodeMessages()
|
||||
rpcMsg.subscriptions = ? pb.decodeSubscriptions()
|
||||
rpcMsg.control = ? pb.decodeControl()
|
||||
ok(rpcMsg)
|
||||
|
|
|
@ -449,9 +449,11 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
remoteSig: Signature
|
||||
remoteSigBytes: seq[byte]
|
||||
|
||||
if not(remoteProof.getField(1, remotePubKeyBytes)):
|
||||
let r1 = remoteProof.getField(1, remotePubKeyBytes)
|
||||
let r2 = remoteProof.getField(2, remoteSigBytes)
|
||||
if r1.isErr() or not(r1.get()):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not(remoteProof.getField(2, remoteSigBytes)):
|
||||
if r2.isErr() or not(r2.get()):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
|
||||
if not remotePubKey.init(remotePubKeyBytes):
|
||||
|
|
|
@ -88,7 +88,7 @@ suite "MinProtobuf test suite":
|
|||
var value: uint64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
doAssert(res.isOk() == true and res.get() == true)
|
||||
value
|
||||
|
||||
proc getFixed32EncodedValue(value: float32): seq[byte] =
|
||||
|
@ -101,7 +101,7 @@ suite "MinProtobuf test suite":
|
|||
var value: float32
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
doAssert(res.isOk() == true and res.get() == true)
|
||||
cast[uint32](value)
|
||||
|
||||
proc getFixed64EncodedValue(value: float64): seq[byte] =
|
||||
|
@ -114,7 +114,7 @@ suite "MinProtobuf test suite":
|
|||
var value: float64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
doAssert(res.isOk() == true and res.get() == true)
|
||||
cast[uint64](value)
|
||||
|
||||
proc getLengthEncodedValue(value: string): seq[byte] =
|
||||
|
@ -134,8 +134,7 @@ suite "MinProtobuf test suite":
|
|||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
|
||||
doAssert(res)
|
||||
doAssert(res.isOk() == true and res.get() == true)
|
||||
value.setLen(valueLen)
|
||||
value
|
||||
|
||||
|
@ -173,17 +172,19 @@ suite "MinProtobuf test suite":
|
|||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[varint] non-existent field test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(2, value)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == 0'u64
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[varint] corrupted header test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
|
@ -192,15 +193,17 @@ suite "MinProtobuf test suite":
|
|||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[varint] empty buffer test":
|
||||
var value: uint64
|
||||
var pb = initProtoBuffer()
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == 0'u64
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[varint] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
|
@ -218,9 +221,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
|
@ -246,9 +252,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 6
|
||||
|
@ -284,17 +293,19 @@ suite "MinProtobuf test suite":
|
|||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[fixed32] non-existent field test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(2, value)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float32(0)
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[fixed32] corrupted header test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
|
@ -303,15 +314,17 @@ suite "MinProtobuf test suite":
|
|||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[fixed32] empty buffer test":
|
||||
var value: float32
|
||||
var pb = initProtoBuffer()
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float32(0)
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[fixed32] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
|
@ -329,9 +342,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
|
@ -360,9 +376,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 5
|
||||
|
@ -397,17 +416,19 @@ suite "MinProtobuf test suite":
|
|||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[fixed64] non-existent field test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var value: float64
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(2, value)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float64(0)
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[fixed64] corrupted header test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
|
@ -416,15 +437,17 @@ suite "MinProtobuf test suite":
|
|||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[fixed64] empty buffer test":
|
||||
var value: float64
|
||||
var pb = initProtoBuffer()
|
||||
let res = pb.getField(1, value)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float64(0)
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
|
||||
test "[fixed64] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
|
@ -442,9 +465,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
|
@ -474,9 +500,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 8
|
||||
|
@ -523,8 +552,9 @@ suite "MinProtobuf test suite":
|
|||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[length] non-existent field test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
|
@ -532,8 +562,10 @@ suite "MinProtobuf test suite":
|
|||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(2, value, valueLen)
|
||||
check:
|
||||
pb.getField(2, value, valueLen) == false
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] corrupted header test":
|
||||
|
@ -544,15 +576,18 @@ suite "MinProtobuf test suite":
|
|||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[length] empty buffer test":
|
||||
var value = newSeq[byte](len(LengthValues[0]))
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer()
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] buffer overflow test":
|
||||
|
@ -562,8 +597,10 @@ suite "MinProtobuf test suite":
|
|||
var value = newString(len(LengthValues[i]) - 1)
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
res.isOk() == true
|
||||
res.get() == false
|
||||
valueLen == len(LengthValues[i])
|
||||
isFullZero(value) == true
|
||||
|
||||
|
@ -578,8 +615,10 @@ suite "MinProtobuf test suite":
|
|||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(4)
|
||||
var valueLen = 0
|
||||
let res = pb2.getField(1, value, valueLen)
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == true
|
||||
res.isOk() == true
|
||||
res.get() == true
|
||||
value == "SOME"
|
||||
|
||||
test "[length] too big message test":
|
||||
|
@ -593,8 +632,9 @@ suite "MinProtobuf test suite":
|
|||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(MaxMessageSize + 1)
|
||||
var valueLen = 0
|
||||
let res = pb2.getField(1, value, valueLen)
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == false
|
||||
res.isErr() == true
|
||||
|
||||
test "[length] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
|
@ -612,9 +652,12 @@ suite "MinProtobuf test suite":
|
|||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r1.get() == true
|
||||
r2.get() == true
|
||||
r3.get() == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
|
@ -662,11 +705,16 @@ suite "MinProtobuf test suite":
|
|||
var lengthValue = newString(10)
|
||||
var lengthSize: int
|
||||
|
||||
let r1 = pb.getField(1, varintValue)
|
||||
let r2 = pb.getField(2, fixed32Value)
|
||||
let r3 = pb.getField(3, fixed64Value)
|
||||
let r4 = pb.getField(4, lengthValue, lengthSize)
|
||||
|
||||
check:
|
||||
pb.getField(1, varintValue) == true
|
||||
pb.getField(2, fixed32Value) == true
|
||||
pb.getField(3, fixed64Value) == true
|
||||
pb.getField(4, lengthValue, lengthSize) == true
|
||||
r1.isOk() == true
|
||||
r2.isOk() == true
|
||||
r3.isOk() == true
|
||||
r4.isOk() == true
|
||||
|
||||
lengthValue.setLen(lengthSize)
|
||||
|
||||
|
|
Loading…
Reference in New Issue