Minprotobuf refactoring 2 (#269)

* Protobuf refactoring stage II.

* Remove NoError.

* Change trace level for invalid message.
This commit is contained in:
Eugene Kabanov 2020-07-15 11:25:39 +03:00 committed by GitHub
parent 9eb5828a42
commit b832668768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 473 additions and 393 deletions

View File

@ -31,11 +31,6 @@ type
ECDSA, ECDSA,
NoSupport NoSupport
CipherScheme* = enum
Aes128 = 0,
Aes256,
Blowfish
DigestSheme* = enum DigestSheme* = enum
Sha256, Sha256,
Sha512 Sha512
@ -283,7 +278,9 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
var buffer: seq[byte] var buffer: seq[byte]
if len(data) > 0: if len(data) > 0:
var pb = initProtoBuffer(@data) var pb = initProtoBuffer(@data)
if pb.getField(1, id) and pb.getField(2, buffer): let r1 = pb.getField(1, id)
let r2 = pb.getField(2, buffer)
if r1.isOk() and r1.get() and r2.isOk() and r2.get():
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0: if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
var scheme = cast[PKScheme](cast[int8](id)) var scheme = cast[PKScheme](cast[int8](id))
when key is PrivateKey: when key is PrivateKey:
@ -743,9 +740,15 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
## ##
## Procedure returns ``true`` on success and ``false`` on error. ## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message) var pb = initProtoBuffer(message)
pb.getField(1, nonce) and pb.getField(2, pubkey) and let r1 = pb.getField(1, nonce)
pb.getField(3, exchanges) and pb.getField(4, ciphers) and let r2 = pb.getField(2, pubkey)
pb.getField(5, hashes) let r3 = pb.getField(3, exchanges)
let r4 = pb.getField(4, ciphers)
let r5 = pb.getField(5, hashes)
r1.isOk() and r1.get() and r2.isOk() and r2.get() and
r3.isOk() and r3.get() and r4.isOk() and r4.get() and
r5.isOk() and r5.get()
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and ## Create SecIO exchange message using ephemeral public key ``epubkey`` and
@ -763,7 +766,9 @@ proc decodeExchange*(message: seq[byte],
## ##
## Procedure returns ``true`` on success and ``false`` on error. ## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message) var pb = initProtoBuffer(message)
pb.getField(1, pubkey) and pb.getField(2, signature) let r1 = pb.getField(1, pubkey)
let r2 = pb.getField(2, signature)
r1.isOk() and r1.get() and r2.isOk() and r2.get()
## Serialization/Deserialization helpers ## Serialization/Deserialization helpers
@ -825,28 +830,37 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
value = sig value = sig
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int, proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
value: var T): bool = value: var T): ProtoResult[bool] =
## Deserialize public/private key from protobuf's message ``pb`` using field
## index ``field``.
##
## On success deserialized key will be stored in ``value``.
var buffer: seq[byte] var buffer: seq[byte]
var key: T var key: T
if not(getField(pb, field, buffer)): let res = ? pb.getField(field, buffer)
return false if not(res):
if len(buffer) == 0: ok(false)
return false
if key.init(buffer):
value = key
true
else: else:
false if key.init(buffer):
value = key
ok(true)
else:
err(ProtoError.IncorrectBlob)
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool = proc getField*(pb: ProtoBuffer, field: int,
value: var Signature): ProtoResult[bool] =
## Deserialize signature from protobuf's message ``pb`` using field index
## ``field``.
##
## On success deserialized signature will be stored in ``value``.
var buffer: seq[byte] var buffer: seq[byte]
var sig: Signature var sig: Signature
if not(getField(pb, field, buffer)): let res = ? pb.getField(field, buffer)
return false if not(res):
if len(buffer) == 0: ok(false)
return false
if sig.init(buffer):
value = sig
true
else: else:
false if sig.init(buffer):
value = sig
ok(true)
else:
err(ProtoError.IncorrectBlob)

View File

@ -1025,31 +1025,34 @@ proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
write(pb, field, value.data.buffer) write(pb, field, value.data.buffer)
proc getField*(pb: var ProtoBuffer, field: int, proc getField*(pb: var ProtoBuffer, field: int,
value: var MultiAddress): bool {.inline.} = value: var MultiAddress): ProtoResult[bool] {.
inline.} =
var buffer: seq[byte] var buffer: seq[byte]
if not(getField(pb, field, buffer)): let res = ? pb.getField(field, buffer)
return false if not(res):
if len(buffer) == 0: ok(false)
return false
let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
true
else: else:
false let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
ok(true)
else:
err(ProtoError.IncorrectBlob)
proc getRepeatedField*(pb: var ProtoBuffer, field: int, proc getRepeatedField*(pb: var ProtoBuffer, field: int,
value: var seq[MultiAddress]): bool {.inline.} = value: var seq[MultiAddress]): ProtoResult[bool] {.
inline.} =
var items: seq[seq[byte]] var items: seq[seq[byte]]
value.setLen(0) value.setLen(0)
if not(getRepeatedField(pb, field, items)): let res = ? pb.getRepeatedField(field, items)
return false if not(res):
if len(items) == 0: ok(false)
return true else:
for item in items: for item in items:
let ma = MultiAddress.init(item) let ma = MultiAddress.init(item)
if ma.isOk(): if ma.isOk():
value.add(ma.get()) value.add(ma.get())
else: else:
value.setLen(0) value.setLen(0)
return false return err(ProtoError.IncorrectBlob)
ok(true)

View File

@ -219,16 +219,17 @@ proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding. ## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
write(pb, field, pid.data) write(pb, field, pid.data)
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool = proc getField*(pb: ProtoBuffer, field: int,
pid: var PeerID): ProtoResult[bool] {.inline.} =
## Read ``PeerID`` from ProtoBuf's message and validate it ## Read ``PeerID`` from ProtoBuf's message and validate it
var buffer: seq[byte] var buffer: seq[byte]
var peerId: PeerID let res = ? pb.getField(field, buffer)
if not(getField(pb, field, buffer)): if not(res):
return false ok(false)
if len(buffer) == 0:
return false
if peerId.init(buffer):
pid = peerId
true
else: else:
false var peerId: PeerID
if peerId.init(buffer):
pid = peerId
ok(true)
else:
err(ProtoError.IncorrectBlob)

View File

@ -11,7 +11,8 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import ../varint, stew/endians2 import ../varint, stew/[endians2, results]
export results
const const
MaxMessageSize* = 1'u shl 22 MaxMessageSize* = 1'u shl 22
@ -51,12 +52,15 @@ type
of StartGroup, EndGroup: of StartGroup, EndGroup:
discard discard
ProtoResult {.pure.} = enum ProtoError* {.pure.} = enum
VarintDecodeError, VarintDecode,
MessageIncompleteError, MessageIncomplete,
BufferOverflowError, BufferOverflow,
MessageSizeTooBigError, MessageTooBig,
NoError BadWireType,
IncorrectBlob
ProtoResult*[T] = Result[T, ProtoError]
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 | ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
hint | hint32 | hint64 | float32 | float64 hint | hint32 | hint64 | float32 | float64
@ -361,7 +365,8 @@ proc finish*(pb: var ProtoBuffer) =
else: else:
pb.offset = 0 pb.offset = 0
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool = proc getHeader(data: var ProtoBuffer,
header: var ProtoHeader): ProtoResult[void] =
var length = 0 var length = 0
var hdr = 0'u64 var hdr = 0'u64
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk(): if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
@ -370,34 +375,34 @@ proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
if wire in SupportedWireTypes: if wire in SupportedWireTypes:
data.offset += length data.offset += length
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire)) header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
true ok()
else: else:
false err(ProtoError.BadWireType)
else: else:
false err(ProtoError.VarintDecode)
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool = proc skipValue(data: var ProtoBuffer, header: ProtoHeader): ProtoResult[void] =
case header.wire case header.wire
of ProtoFieldKind.Varint: of ProtoFieldKind.Varint:
var length = 0 var length = 0
var value = 0'u64 var value = 0'u64
if PB.getUVarint(data.toOpenArray(), length, value).isOk(): if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length data.offset += length
true ok()
else: else:
false err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed32: of ProtoFieldKind.Fixed32:
if data.isEnough(sizeof(uint32)): if data.isEnough(sizeof(uint32)):
data.offset += sizeof(uint32) data.offset += sizeof(uint32)
true ok()
else: else:
false err(ProtoError.VarintDecode)
of ProtoFieldKind.Fixed64: of ProtoFieldKind.Fixed64:
if data.isEnough(sizeof(uint64)): if data.isEnough(sizeof(uint64)):
data.offset += sizeof(uint64) data.offset += sizeof(uint64)
true ok()
else: else:
false err(ProtoError.VarintDecode)
of ProtoFieldKind.Length: of ProtoFieldKind.Length:
var length = 0 var length = 0
var bsize = 0'u64 var bsize = 0'u64
@ -406,19 +411,19 @@ proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
if bsize <= uint64(MaxMessageSize): if bsize <= uint64(MaxMessageSize):
if data.isEnough(int(bsize)): if data.isEnough(int(bsize)):
data.offset += int(bsize) data.offset += int(bsize)
true ok()
else: else:
false err(ProtoError.MessageIncomplete)
else: else:
false err(ProtoError.MessageTooBig)
else: else:
false err(ProtoError.VarintDecode)
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup: of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
false err(ProtoError.BadWireType)
proc getValue[T: ProtoScalar](data: var ProtoBuffer, proc getValue[T: ProtoScalar](data: var ProtoBuffer,
header: ProtoHeader, header: ProtoHeader,
outval: var T): ProtoResult = outval: var T): ProtoResult[void] =
when (T is uint64) or (T is uint32) or (T is uint): when (T is uint64) or (T is uint32) or (T is uint):
doAssert(header.wire == ProtoFieldKind.Varint) doAssert(header.wire == ProtoFieldKind.Varint)
var length = 0 var length = 0
@ -426,9 +431,9 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer,
if PB.getUVarint(data.toOpenArray(), length, value).isOk(): if PB.getUVarint(data.toOpenArray(), length, value).isOk():
data.offset += length data.offset += length
outval = value outval = value
ProtoResult.NoError ok()
else: else:
ProtoResult.VarintDecodeError err(ProtoError.VarintDecode)
elif (T is zint64) or (T is zint32) or (T is zint) or elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint): (T is hint64) or (T is hint32) or (T is hint):
doAssert(header.wire == ProtoFieldKind.Varint) doAssert(header.wire == ProtoFieldKind.Varint)
@ -437,29 +442,29 @@ proc getValue[T: ProtoScalar](data: var ProtoBuffer,
if getSVarint(data.toOpenArray(), length, value).isOk(): if getSVarint(data.toOpenArray(), length, value).isOk():
data.offset += length data.offset += length
outval = value outval = value
ProtoResult.NoError ok()
else: else:
ProtoResult.VarintDecodeError err(ProtoError.VarintDecode)
elif T is float32: elif T is float32:
doAssert(header.wire == ProtoFieldKind.Fixed32) doAssert(header.wire == ProtoFieldKind.Fixed32)
if data.isEnough(sizeof(float32)): if data.isEnough(sizeof(float32)):
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray())) outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
data.offset += sizeof(float32) data.offset += sizeof(float32)
ProtoResult.NoError ok()
else: else:
ProtoResult.MessageIncompleteError err(ProtoError.MessageIncomplete)
elif T is float64: elif T is float64:
doAssert(header.wire == ProtoFieldKind.Fixed64) doAssert(header.wire == ProtoFieldKind.Fixed64)
if data.isEnough(sizeof(float64)): if data.isEnough(sizeof(float64)):
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray())) outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
data.offset += sizeof(float64) data.offset += sizeof(float64)
ProtoResult.NoError ok()
else: else:
ProtoResult.MessageIncompleteError err(ProtoError.MessageIncomplete)
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var openarray[T], outBytes: var openarray[T],
outLength: var int): ProtoResult = outLength: var int): ProtoResult[void] =
doAssert(header.wire == ProtoFieldKind.Length) doAssert(header.wire == ProtoFieldKind.Length)
var length = 0 var length = 0
var bsize = 0'u64 var bsize = 0'u64
@ -474,20 +479,20 @@ proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
if bsize > 0'u64: if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize) data.offset += int(bsize)
ProtoResult.NoError ok()
else: else:
# Buffer overflow should not be critical failure # Buffer overflow should not be critical failure
data.offset += int(bsize) data.offset += int(bsize)
ProtoResult.BufferOverflowError err(ProtoError.BufferOverflow)
else: else:
ProtoResult.MessageIncompleteError err(ProtoError.MessageIncomplete)
else: else:
ProtoResult.MessageSizeTooBigError err(ProtoError.MessageTooBig)
else: else:
ProtoResult.VarintDecodeError err(ProtoError.VarintDecode)
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
outBytes: var T): ProtoResult = outBytes: var T): ProtoResult[void] =
doAssert(header.wire == ProtoFieldKind.Length) doAssert(header.wire == ProtoFieldKind.Length)
var length = 0 var length = 0
var bsize = 0'u64 var bsize = 0'u64
@ -501,27 +506,24 @@ proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
if bsize > 0'u64: if bsize > 0'u64:
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
data.offset += int(bsize) data.offset += int(bsize)
ProtoResult.NoError ok()
else: else:
ProtoResult.MessageIncompleteError err(ProtoError.MessageIncomplete)
else: else:
ProtoResult.MessageSizeTooBigError err(ProtoError.MessageTooBig)
else: else:
ProtoResult.VarintDecodeError err(ProtoError.VarintDecode)
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int, proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var T): bool = output: var T): ProtoResult[bool] =
checkFieldNumber(field) checkFieldNumber(field)
var value: T var current: T
var res = false var res = false
var pb = data var pb = data
output = T(0)
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): ? pb.getHeader(header)
output = T(0)
return false
let wireCheck = let wireCheck =
when (T is uint64) or (T is uint32) or (T is uint) or when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or (T is zint64) or (T is zint32) or (T is zint) or
@ -533,28 +535,29 @@ proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
header.wire == ProtoFieldKind.Fixed64 header.wire == ProtoFieldKind.Fixed64
if header.index == uint64(field): if header.index == uint64(field):
if wireCheck: if wireCheck:
let r = getValue(pb, header, value) var value: T
case r let vres = pb.getValue(header, value)
of ProtoResult.NoError: if vres.isOk():
res = true res = true
output = value current = value
else: else:
return false return err(vres.error)
else: else:
# We are ignoring wire types different from what we expect, because it # We are ignoring wire types different from what we expect, because it
# is how `protoc` is working. # is how `protoc` is working.
if not(skipValue(pb, header)): ? pb.skipValue(header)
output = T(0)
return false
else: else:
if not(skipValue(pb, header)): ? pb.skipValue(header)
output = T(0)
return false if res:
res output = current
ok(true)
else:
ok(false)
proc getField*[T: byte|char](data: ProtoBuffer, field: int, proc getField*[T: byte|char](data: ProtoBuffer, field: int,
output: var openarray[T], output: var openarray[T],
outlen: var int): bool = outlen: var int): ProtoResult[bool] =
checkFieldNumber(field) checkFieldNumber(field)
var pb = data var pb = data
var res = false var res = false
@ -563,182 +566,191 @@ proc getField*[T: byte|char](data: ProtoBuffer, field: int,
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): let hres = pb.getHeader(header)
if hres.isErr():
if len(output) > 0: if len(output) > 0:
zeroMem(addr output[0], len(output)) zeroMem(addr output[0], len(output))
outlen = 0 outlen = 0
return false return err(hres.error)
if header.index == uint64(field): if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length: if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output, outlen) let vres = pb.getValue(header, output, outlen)
case r if vres.isOk():
of ProtoResult.NoError:
res = true res = true
of ProtoResult.BufferOverflowError: else:
# Buffer overflow error is not critical error, we still can get # Buffer overflow error is not critical error, we still can get
# field values with proper size. # field values with proper size.
discard if vres.error != ProtoError.BufferOverflow:
else: if len(output) > 0:
if len(output) > 0: zeroMem(addr output[0], len(output))
zeroMem(addr output[0], len(output)) outlen = 0
return false return err(vres.error)
else: else:
# We are ignoring wire types different from ProtoFieldKind.Length, # We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working. # because it is how `protoc` is working.
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
if len(output) > 0: if len(output) > 0:
zeroMem(addr output[0], len(output)) zeroMem(addr output[0], len(output))
outlen = 0 outlen = 0
return false return err(sres.error)
else: else:
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
if len(output) > 0: if len(output) > 0:
zeroMem(addr output[0], len(output)) zeroMem(addr output[0], len(output))
outlen = 0 outlen = 0
return false return err(sres.error)
res if res:
ok(true)
else:
ok(false)
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int, proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var T): bool = output: var T): ProtoResult[bool] =
checkFieldNumber(field) checkFieldNumber(field)
var res = false var res = false
var pb = data var pb = data
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): let hres = pb.getHeader(header)
if hres.isErr():
output.setLen(0) output.setLen(0)
return false return err(hres.error)
if header.index == uint64(field): if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length: if header.wire == ProtoFieldKind.Length:
let r = getValue(pb, header, output) let vres = pb.getValue(header, output)
case r if vres.isOk():
of ProtoResult.NoError:
res = true res = true
of ProtoResult.BufferOverflowError:
# Buffer overflow error is not critical error, we still can get
# field values with proper size.
discard
else: else:
output.setLen(0) output.setLen(0)
return false return err(vres.error)
else: else:
# We are ignoring wire types different from ProtoFieldKind.Length, # We are ignoring wire types different from ProtoFieldKind.Length,
# because it is how `protoc` is working. # because it is how `protoc` is working.
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
else: else:
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
if res:
res ok(true)
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
inline.} =
var buffer: seq[byte]
if pb.getField(field, buffer):
output = initProtoBuffer(buffer)
true
else: else:
false ok(false)
proc getField*(pb: ProtoBuffer, field: int,
output: var ProtoBuffer): ProtoResult[bool] {.inline.} =
var buffer: seq[byte]
let res = pb.getField(field, buffer)
if res.isOk():
if res.get():
output = initProtoBuffer(buffer)
ok(true)
else:
ok(false)
else:
err(res.error)
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int, proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
output: var seq[T]): bool = output: var seq[T]): ProtoResult[bool] =
checkFieldNumber(field) checkFieldNumber(field)
var pb = data var pb = data
output.setLen(0) output.setLen(0)
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): let hres = pb.getHeader(header)
if hres.isErr():
output.setLen(0) output.setLen(0)
return false return err(hres.error)
if header.index == uint64(field): if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length: if header.wire == ProtoFieldKind.Length:
var item: T var item: T
let r = getValue(pb, header, item) let vres = pb.getValue(header, item)
case r if vres.isOk():
of ProtoResult.NoError:
output.add(item) output.add(item)
else: else:
output.setLen(0) output.setLen(0)
return false return err(vres.error)
else: else:
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
else: else:
if not(skipValue(pb, header)): let sres = pb.skipValue(header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
if len(output) > 0: if len(output) > 0:
true ok(true)
else: else:
false ok(false)
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer, proc getRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
field: int, output: var seq[T]): ProtoResult[bool] =
output: var seq[T]): bool =
checkFieldNumber(field) checkFieldNumber(field)
var pb = data var pb = data
output.setLen(0) output.setLen(0)
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): let hres = pb.getHeader(header)
if hres.isErr():
output.setLen(0) output.setLen(0)
return false return err(hres.error)
if header.index == uint64(field): if header.index == uint64(field):
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32, if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
ProtoFieldKind.Fixed64}: ProtoFieldKind.Fixed64}:
var item: T var item: T
let r = getValue(pb, header, item) let vres = getValue(pb, header, item)
case r if vres.isOk():
of ProtoResult.NoError:
output.add(item) output.add(item)
else: else:
output.setLen(0) output.setLen(0)
return false return err(vres.error)
else: else:
if not(skipValue(pb, header)): let sres = skipValue(pb, header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
else: else:
if not(skipValue(pb, header)): let sres = skipValue(pb, header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
if len(output) > 0: if len(output) > 0:
true ok(true)
else: else:
false ok(false)
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
output: var seq[T]): bool = output: var seq[T]): ProtoResult[bool] =
checkFieldNumber(field) checkFieldNumber(field)
var pb = data var pb = data
output.setLen(0) output.setLen(0)
while not(pb.isEmpty()): while not(pb.isEmpty()):
var header: ProtoHeader var header: ProtoHeader
if not(pb.getHeader(header)): let hres = pb.getHeader(header)
if hres.isErr():
output.setLen(0) output.setLen(0)
return false return err(hres.error)
if header.index == uint64(field): if header.index == uint64(field):
if header.wire == ProtoFieldKind.Length: if header.wire == ProtoFieldKind.Length:
var arritem: seq[byte] var arritem: seq[byte]
let rarr = getValue(pb, header, arritem) let ares = getValue(pb, header, arritem)
case rarr if ares.isOk():
of ProtoResult.NoError:
var pbarr = initProtoBuffer(arritem) var pbarr = initProtoBuffer(arritem)
let itemHeader = let itemHeader =
when (T is uint64) or (T is uint32) or (T is uint) or when (T is uint64) or (T is uint32) or (T is uint) or
@ -751,29 +763,30 @@ proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
ProtoHeader(wire: ProtoFieldKind.Fixed64) ProtoHeader(wire: ProtoFieldKind.Fixed64)
while not(pbarr.isEmpty()): while not(pbarr.isEmpty()):
var item: T var item: T
let res = getValue(pbarr, itemHeader, item) let vres = getValue(pbarr, itemHeader, item)
case res if vres.isOk():
of ProtoResult.NoError:
output.add(item) output.add(item)
else: else:
output.setLen(0) output.setLen(0)
return false return err(vres.error)
else: else:
output.setLen(0) output.setLen(0)
return false return err(ares.error)
else: else:
if not(skipValue(pb, header)): let sres = skipValue(pb, header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
else: else:
if not(skipValue(pb, header)): let sres = skipValue(pb, header)
if sres.isErr():
output.setLen(0) output.setLen(0)
return false return err(sres.error)
if len(output) > 0: if len(output) > 0:
true ok(true)
else: else:
false ok(false)
proc getVarintValue*(data: var ProtoBuffer, field: int, proc getVarintValue*(data: var ProtoBuffer, field: int,
value: var SomeVarint): int {.deprecated.} = value: var SomeVarint): int {.deprecated.} =

View File

@ -46,52 +46,56 @@ type
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
result = initProtoBuffer() result = initProtoBuffer()
result.write(1, peerInfo.publicKey.get().getBytes().tryGet()) result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
for ma in peerInfo.addrs: for ma in peerInfo.addrs:
result.write(2, ma.data.buffer) result.write(2, ma.data.buffer)
for proto in peerInfo.protocols: for proto in peerInfo.protocols:
result.write(3, proto) result.write(3, proto)
result.write(4, observedAddr.data.buffer) result.write(4, observedAddr.data.buffer)
let protoVersion = ProtoVersion let protoVersion = ProtoVersion
result.write(5, protoVersion) result.write(5, protoVersion)
let agentVersion = AgentVersion let agentVersion = AgentVersion
result.write(6, agentVersion) result.write(6, agentVersion)
result.finish() result.finish()
proc decodeMsg*(buf: seq[byte]): IdentifyInfo = proc decodeMsg*(buf: seq[byte]): Option[IdentifyInfo] =
var
iinfo: IdentifyInfo
pubKey: PublicKey
oaddr: MultiAddress
protoVersion: string
agentVersion: string
var pb = initProtoBuffer(buf) var pb = initProtoBuffer(buf)
var pubKey: PublicKey let r1 = pb.getField(1, pubKey)
if pb.getField(1, pubKey): let r2 = pb.getRepeatedField(2, iinfo.addrs)
trace "read public key from message", pubKey = ($pubKey).shortLog let r3 = pb.getRepeatedField(3, iinfo.protos)
result.pubKey = some(pubKey) let r4 = pb.getField(4, oaddr)
let r5 = pb.getField(5, protoVersion)
let r6 = pb.getField(6, agentVersion)
if pb.getRepeatedField(2, result.addrs): let res = r1.isOk() and r2.isOk() and r3.isOk() and
trace "read addresses from message", addresses = result.addrs r4.isOk() and r5.isOk() and r6.isOk()
if pb.getRepeatedField(3, result.protos): if res:
trace "read protos from message", protocols = result.protos if r1.get():
iinfo.pubKey = some(pubKey)
var observableAddr: MultiAddress if r4.get():
if pb.getField(4, observableAddr): iinfo.observedAddr = some(oaddr)
trace "read observableAddr from message", address = observableAddr if r5.get():
result.observedAddr = some(observableAddr) iinfo.protoVersion = some(protoVersion)
if r6.get():
var protoVersion = "" iinfo.agentVersion = some(agentVersion)
if pb.getField(5, protoVersion): trace "decodeMsg: decoded message", pubkey = ($pubKey).shortLog,
trace "read protoVersion from message", protoVersion = protoVersion addresses = $iinfo.addrs, protocols = $iinfo.protos,
result.protoVersion = some(protoVersion) observable_address = $iinfo.observedAddr,
proto_version = $iinfo.protoVersion,
var agentVersion = "" agent_version = $iinfo.agentVersion
if pb.getField(6, agentVersion): some(iinfo)
trace "read agentVersion from message", agentVersion = agentVersion else:
result.agentVersion = some(agentVersion) trace "decodeMsg: failed to decode received message"
none[IdentifyInfo]()
proc newIdentify*(peerInfo: PeerInfo): Identify = proc newIdentify*(peerInfo: PeerInfo): Identify =
new result new result
@ -122,11 +126,13 @@ proc identify*(p: Identify,
trace "initiating identify", peer = $conn trace "initiating identify", peer = $conn
var message = await conn.readLp(64*1024) var message = await conn.readLp(64*1024)
if len(message) == 0: if len(message) == 0:
trace "identify: Invalid or empty message received!" trace "identify: Empty message received!"
raise newException(IdentityInvalidMsgError, raise newException(IdentityInvalidMsgError, "Empty message received!")
"Invalid or empty message received!")
result = decodeMsg(message) let infoOpt = decodeMsg(message)
if infoOpt.isNone():
raise newException(IdentityInvalidMsgError, "Incorrect message received!")
result = infoOpt.get()
if not isNil(remotePeerInfo) and result.pubKey.isSome: if not isNil(remotePeerInfo) and result.pubKey.isSome:
let peer = PeerID.init(result.pubKey.get()) let peer = PeerID.init(result.pubKey.get())

View File

@ -43,7 +43,7 @@ type
RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.}
func hash*(p: PubSubPeer): Hash = func hash*(p: PubSubPeer): Hash =
# int is either 32/64, so intptr basically, pubsubpeer is a ref # int is either 32/64, so intptr basically, pubsubpeer is a ref
cast[pointer](p).hash cast[pointer](p).hash
@ -114,7 +114,13 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
trace "message already received, skipping", peer = p.id trace "message already received, skipping", peer = p.id
continue continue
var msg = decodeRpcMsg(data) var rmsg = decodeRpcMsg(data)
if rmsg.isErr():
notice "failed to decode msg from peer", peer = p.id
break
var msg = rmsg.get()
trace "decoded msg from peer", peer = p.id, msg = msg.shortLog trace "decoded msg from peer", peer = p.id, msg = msg.shortLog
# trigger hooks # trigger hooks
p.recvObservers(msg) p.recvObservers(msg)
@ -149,11 +155,11 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
p.sendObservers(mm) p.sendObservers(mm)
let encoded = encodeRpcMsg(mm) let encoded = encodeRpcMsg(mm)
if encoded.buffer.len <= 0: if encoded.len <= 0:
trace "empty message, skipping", peer = p.id trace "empty message, skipping", peer = p.id
return return
let digest = $(sha256.digest(encoded.buffer)) let digest = $(sha256.digest(encoded))
if digest in p.sentRpcCache: if digest in p.sentRpcCache:
trace "message already sent to peer, skipping", peer = p.id trace "message already sent to peer, skipping", peer = p.id
libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id]) libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id])
@ -164,8 +170,8 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
encoded = digest encoded = digest
if p.connected: # this can happen if the remote disconnected if p.connected: # this can happen if the remote disconnected
trace "sending encoded msgs to peer", peer = p.id, trace "sending encoded msgs to peer", peer = p.id,
encoded = encoded.buffer.shortLog encoded = encoded.shortLog
await p.sendConn.writeLp(encoded.buffer) await p.sendConn.writeLp(encoded)
p.sentRpcCache.put(digest) p.sentRpcCache.put(digest)
for m in msgs: for m in msgs:

View File

@ -80,163 +80,151 @@ proc encodeMessage*(msg: Message): seq[byte] =
proc write*(pb: var ProtoBuffer, field: int, msg: Message) = proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
pb.write(field, encodeMessage(msg)) pb.write(field, encodeMessage(msg))
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} = proc decodeGraft*(pb: ProtoBuffer): ProtoResult[ControlGraft] {.
inline.} =
trace "decodeGraft: decoding message" trace "decodeGraft: decoding message"
var control = ControlGraft() var control = ControlGraft()
var topicId: string if ? pb.getField(1, control.topicId):
if pb.getField(1, topicId): trace "decodeGraft: read topicId", topic_id = control.topicId
control.topicId = topicId
trace "decodeGraft: read topicId", topic_id = topicId
else: else:
trace "decodeGraft: topicId is missing" trace "decodeGraft: topicId is missing"
control ok(control)
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} = proc decodePrune*(pb: ProtoBuffer): ProtoResult[ControlPrune] {.
inline.} =
trace "decodePrune: decoding message" trace "decodePrune: decoding message"
var control = ControlPrune() var control = ControlPrune()
var topicId: string if ? pb.getField(1, control.topicId):
if pb.getField(1, topicId): trace "decodePrune: read topicId", topic_id = control.topicId
control.topicId = topicId
trace "decodePrune: read topicId", topic_id = topicId
else: else:
trace "decodePrune: topicId is missing" trace "decodePrune: topicId is missing"
control ok(control)
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} = proc decodeIHave*(pb: ProtoBuffer): ProtoResult[ControlIHave] {.
inline.} =
trace "decodeIHave: decoding message" trace "decodeIHave: decoding message"
var control = ControlIHave() var control = ControlIHave()
var topicId: string if ? pb.getField(1, control.topicId):
if pb.getField(1, topicId): trace "decodeIHave: read topicId", topic_id = control.topicId
control.topicId = topicId
trace "decodeIHave: read topicId", topic_id = topicId
else: else:
trace "decodeIHave: topicId is missing" trace "decodeIHave: topicId is missing"
if pb.getRepeatedField(2, control.messageIDs): if ? pb.getRepeatedField(2, control.messageIDs):
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
else: else:
trace "decodeIHave: no messageIDs" trace "decodeIHave: no messageIDs"
control ok(control)
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} = proc decodeIWant*(pb: ProtoBuffer): ProtoResult[ControlIWant] {.inline.} =
trace "decodeIWant: decoding message" trace "decodeIWant: decoding message"
var control = ControlIWant() var control = ControlIWant()
if pb.getRepeatedField(1, control.messageIDs): if ? pb.getRepeatedField(1, control.messageIDs):
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
else: else:
trace "decodeIWant: no messageIDs" trace "decodeIWant: no messageIDs"
ok(control)
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} = proc decodeControl*(pb: ProtoBuffer): ProtoResult[Option[ControlMessage]] {.
inline.} =
trace "decodeControl: decoding message" trace "decodeControl: decoding message"
var buffer: seq[byte] var buffer: seq[byte]
if pb.getField(3, buffer): if ? pb.getField(3, buffer):
var control: ControlMessage var control: ControlMessage
var cpb = initProtoBuffer(buffer) var cpb = initProtoBuffer(buffer)
var ihavepbs: seq[seq[byte]] var ihavepbs: seq[seq[byte]]
var iwantpbs: seq[seq[byte]] var iwantpbs: seq[seq[byte]]
var graftpbs: seq[seq[byte]] var graftpbs: seq[seq[byte]]
var prunepbs: seq[seq[byte]] var prunepbs: seq[seq[byte]]
if ? cpb.getRepeatedField(1, ihavepbs):
discard cpb.getRepeatedField(1, ihavepbs) for item in ihavepbs:
discard cpb.getRepeatedField(2, iwantpbs) control.ihave.add(? decodeIHave(initProtoBuffer(item)))
discard cpb.getRepeatedField(3, graftpbs) if ? cpb.getRepeatedField(2, iwantpbs):
discard cpb.getRepeatedField(4, prunepbs) for item in iwantpbs:
control.iwant.add(? decodeIWant(initProtoBuffer(item)))
for item in ihavepbs: if ? cpb.getRepeatedField(3, graftpbs):
control.ihave.add(decodeIHave(initProtoBuffer(item))) for item in graftpbs:
for item in iwantpbs: control.graft.add(? decodeGraft(initProtoBuffer(item)))
control.iwant.add(decodeIWant(initProtoBuffer(item))) if ? cpb.getRepeatedField(4, prunepbs):
for item in graftpbs: for item in prunepbs:
control.graft.add(decodeGraft(initProtoBuffer(item))) control.prune.add(? decodePrune(initProtoBuffer(item)))
for item in prunepbs: trace "decodeControl: message statistics", graft_count = len(control.graft),
control.prune.add(decodePrune(initProtoBuffer(item))) prune_count = len(control.prune),
ihave_count = len(control.ihave),
trace "decodeControl: " iwant_count = len(control.iwant)
some(control) ok(some(control))
else: else:
none[ControlMessage]() ok(none[ControlMessage]())
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} = proc decodeSubscription*(pb: ProtoBuffer): ProtoResult[SubOpts] {.inline.} =
trace "decodeSubscription: decoding message" trace "decodeSubscription: decoding message"
var subflag: uint64 var subflag: uint64
var sub = SubOpts() var sub = SubOpts()
if pb.getField(1, subflag): if ? pb.getField(1, subflag):
sub.subscribe = bool(subflag) sub.subscribe = bool(subflag)
trace "decodeSubscription: read subscribe", subscribe = subflag trace "decodeSubscription: read subscribe", subscribe = subflag
else: else:
trace "decodeSubscription: subscribe is missing" trace "decodeSubscription: subscribe is missing"
if pb.getField(2, sub.topic): if ? pb.getField(2, sub.topic):
trace "decodeSubscription: read topic", topic = sub.topic trace "decodeSubscription: read topic", topic = sub.topic
else: else:
trace "decodeSubscription: topic is missing" trace "decodeSubscription: topic is missing"
ok(sub)
sub proc decodeSubscriptions*(pb: ProtoBuffer): ProtoResult[seq[SubOpts]] {.
inline.} =
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
trace "decodeSubscriptions: decoding message" trace "decodeSubscriptions: decoding message"
var subpbs: seq[seq[byte]] var subpbs: seq[seq[byte]]
var subs: seq[SubOpts] var subs: seq[SubOpts]
if pb.getRepeatedField(1, subpbs): let res = ? pb.getRepeatedField(1, subpbs)
if res:
trace "decodeSubscriptions: read subscriptions", count = len(subpbs) trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
for item in subpbs: for item in subpbs:
let sub = decodeSubscription(initProtoBuffer(item)) subs.add(? decodeSubscription(initProtoBuffer(item)))
subs.add(sub) if len(subs) == 0:
trace "decodeSubscription: no subscriptions found"
ok(subs)
if len(subs) == 0: proc decodeMessage*(pb: ProtoBuffer): ProtoResult[Message] {.inline.} =
trace "decodeSubscription: no subscriptions found"
subs
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
trace "decodeMessage: decoding message" trace "decodeMessage: decoding message"
var msg: Message var msg: Message
if pb.getField(1, msg.fromPeer): if ? pb.getField(1, msg.fromPeer):
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty() trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
else: else:
trace "decodeMessage: fromPeer is missing" trace "decodeMessage: fromPeer is missing"
if ? pb.getField(2, msg.data):
if pb.getField(2, msg.data):
trace "decodeMessage: read data", data = msg.data.shortLog() trace "decodeMessage: read data", data = msg.data.shortLog()
else: else:
trace "decodeMessage: data is missing" trace "decodeMessage: data is missing"
if ? pb.getField(3, msg.seqno):
if pb.getField(3, msg.seqno):
trace "decodeMessage: read seqno", seqno = msg.data.shortLog() trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
else: else:
trace "decodeMessage: seqno is missing" trace "decodeMessage: seqno is missing"
if ? pb.getRepeatedField(4, msg.topicIDs):
if pb.getRepeatedField(4, msg.topicIDs):
trace "decodeMessage: read topics", topic_ids = msg.topicIDs trace "decodeMessage: read topics", topic_ids = msg.topicIDs
else: else:
trace "decodeMessage: topics are missing" trace "decodeMessage: topics are missing"
if ? pb.getField(5, msg.signature):
if pb.getField(5, msg.signature):
trace "decodeMessage: read signature", signature = msg.signature.shortLog() trace "decodeMessage: read signature", signature = msg.signature.shortLog()
else: else:
trace "decodeMessage: signature is missing" trace "decodeMessage: signature is missing"
if ? pb.getField(6, msg.key):
if pb.getField(6, msg.key):
trace "decodeMessage: read public key", key = msg.key.shortLog() trace "decodeMessage: read public key", key = msg.key.shortLog()
else: else:
trace "decodeMessage: public key is missing" trace "decodeMessage: public key is missing"
ok(msg)
msg proc decodeMessages*(pb: ProtoBuffer): ProtoResult[seq[Message]] {.inline.} =
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
trace "decodeMessages: decoding message" trace "decodeMessages: decoding message"
var msgpbs: seq[seq[byte]] var msgpbs: seq[seq[byte]]
var msgs: seq[Message] var msgs: seq[Message]
if pb.getRepeatedField(2, msgpbs): if ? pb.getRepeatedField(2, msgpbs):
trace "decodeMessages: read messages", count = len(msgpbs) trace "decodeMessages: read messages", count = len(msgpbs)
for item in msgpbs: for item in msgpbs:
let msg = decodeMessage(initProtoBuffer(item)) msgs.add(? decodeMessage(initProtoBuffer(item)))
msgs.add(msg) else:
if len(msgs) == 0:
trace "decodeMessages: no messages found" trace "decodeMessages: no messages found"
ok(msgs)
msgs proc encodeRpcMsg*(msg: RPCMsg): seq[byte] =
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
trace "encodeRpcMsg: encoding message", msg = msg.shortLog() trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
var pb = initProtoBuffer() var pb = initProtoBuffer()
for item in msg.subscriptions: for item in msg.subscriptions:
@ -247,14 +235,13 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
pb.write(3, msg.control.get()) pb.write(3, msg.control.get())
if len(pb.buffer) > 0: if len(pb.buffer) > 0:
pb.finish() pb.finish()
result = pb pb.buffer
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg = proc decodeRpcMsg*(msg: seq[byte]): ProtoResult[RPCMsg] {.inline.} =
trace "decodeRpcMsg: decoding message", msg = msg.shortLog() trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
var pb = initProtoBuffer(msg) var pb = initProtoBuffer(msg)
var rpcMsg: RPCMsg var rpcMsg: RPCMsg
rpcMsg.messages = pb.decodeMessages() rpcMsg.messages = ? pb.decodeMessages()
rpcMsg.subscriptions = pb.decodeSubscriptions() rpcMsg.subscriptions = ? pb.decodeSubscriptions()
rpcMsg.control = pb.decodeControl() rpcMsg.control = ? pb.decodeControl()
ok(rpcMsg)
rpcMsg

View File

@ -449,9 +449,11 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
remoteSig: Signature remoteSig: Signature
remoteSigBytes: seq[byte] remoteSigBytes: seq[byte]
if not(remoteProof.getField(1, remotePubKeyBytes)): let r1 = remoteProof.getField(1, remotePubKeyBytes)
let r2 = remoteProof.getField(2, remoteSigBytes)
if r1.isErr() or not(r1.get()):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not(remoteProof.getField(2, remoteSigBytes)): if r2.isErr() or not(r2.get()):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes): if not remotePubKey.init(remotePubKeyBytes):

View File

@ -88,7 +88,7 @@ suite "MinProtobuf test suite":
var value: uint64 var value: uint64
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value) let res = pb.getField(1, value)
doAssert(res) doAssert(res.isOk() == true and res.get() == true)
value value
proc getFixed32EncodedValue(value: float32): seq[byte] = proc getFixed32EncodedValue(value: float32): seq[byte] =
@ -101,7 +101,7 @@ suite "MinProtobuf test suite":
var value: float32 var value: float32
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value) let res = pb.getField(1, value)
doAssert(res) doAssert(res.isOk() == true and res.get() == true)
cast[uint32](value) cast[uint32](value)
proc getFixed64EncodedValue(value: float64): seq[byte] = proc getFixed64EncodedValue(value: float64): seq[byte] =
@ -114,7 +114,7 @@ suite "MinProtobuf test suite":
var value: float64 var value: float64
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value) let res = pb.getField(1, value)
doAssert(res) doAssert(res.isOk() == true and res.get() == true)
cast[uint64](value) cast[uint64](value)
proc getLengthEncodedValue(value: string): seq[byte] = proc getLengthEncodedValue(value: string): seq[byte] =
@ -134,8 +134,7 @@ suite "MinProtobuf test suite":
var valueLen = 0 var valueLen = 0
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen) let res = pb.getField(1, value, valueLen)
doAssert(res.isOk() == true and res.get() == true)
doAssert(res)
value.setLen(valueLen) value.setLen(valueLen)
value value
@ -173,17 +172,19 @@ suite "MinProtobuf test suite":
# corrupting # corrupting
data.setLen(len(data) - 1) data.setLen(len(data) - 1)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[varint] non-existent field test": test "[varint] non-existent field test":
for i in 0 ..< len(VarintValues): for i in 0 ..< len(VarintValues):
var value: uint64 var value: uint64
var data = getVarintEncodedValue(VarintValues[i]) var data = getVarintEncodedValue(VarintValues[i])
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(2, value)
check: check:
pb.getField(2, value) == false res.isOk() == true
value == 0'u64 res.get() == false
test "[varint] corrupted header test": test "[varint] corrupted header test":
for i in 0 ..< len(VarintValues): for i in 0 ..< len(VarintValues):
@ -192,15 +193,17 @@ suite "MinProtobuf test suite":
var data = getVarintEncodedValue(VarintValues[i]) var data = getVarintEncodedValue(VarintValues[i])
data.corruptHeader(k) data.corruptHeader(k)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[varint] empty buffer test": test "[varint] empty buffer test":
var value: uint64 var value: uint64
var pb = initProtoBuffer() var pb = initProtoBuffer()
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isOk() == true
value == 0'u64 res.get() == false
test "[varint] Repeated field test": test "[varint] Repeated field test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()
@ -218,9 +221,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getRepeatedField(2, fieldarr2) let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3) let r3 = pb2.getRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 1 len(fieldarr2) == 1
len(fieldarr1) == 4 len(fieldarr1) == 4
@ -246,9 +252,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3) let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 2 len(fieldarr2) == 2
len(fieldarr1) == 6 len(fieldarr1) == 6
@ -284,17 +293,19 @@ suite "MinProtobuf test suite":
# corrupting # corrupting
data.setLen(len(data) - 1) data.setLen(len(data) - 1)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[fixed32] non-existent field test": test "[fixed32] non-existent field test":
for i in 0 ..< len(Fixed32Values): for i in 0 ..< len(Fixed32Values):
var value: float32 var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i])) var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(2, value)
check: check:
pb.getField(2, value) == false res.isOk() == true
value == float32(0) res.get() == false
test "[fixed32] corrupted header test": test "[fixed32] corrupted header test":
for i in 0 ..< len(Fixed32Values): for i in 0 ..< len(Fixed32Values):
@ -303,15 +314,17 @@ suite "MinProtobuf test suite":
var data = getFixed32EncodedValue(float32(Fixed32Values[i])) var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
data.corruptHeader(k) data.corruptHeader(k)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[fixed32] empty buffer test": test "[fixed32] empty buffer test":
var value: float32 var value: float32
var pb = initProtoBuffer() var pb = initProtoBuffer()
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isOk() == true
value == float32(0) res.get() == false
test "[fixed32] Repeated field test": test "[fixed32] Repeated field test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()
@ -329,9 +342,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getRepeatedField(2, fieldarr2) let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3) let r3 = pb2.getRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 1 len(fieldarr2) == 1
len(fieldarr1) == 4 len(fieldarr1) == 4
@ -360,9 +376,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3) let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 2 len(fieldarr2) == 2
len(fieldarr1) == 5 len(fieldarr1) == 5
@ -397,17 +416,19 @@ suite "MinProtobuf test suite":
# corrupting # corrupting
data.setLen(len(data) - 1) data.setLen(len(data) - 1)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[fixed64] non-existent field test": test "[fixed64] non-existent field test":
for i in 0 ..< len(Fixed64Values): for i in 0 ..< len(Fixed64Values):
var value: float64 var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(2, value)
check: check:
pb.getField(2, value) == false res.isOk() == true
value == float64(0) res.get() == false
test "[fixed64] corrupted header test": test "[fixed64] corrupted header test":
for i in 0 ..< len(Fixed64Values): for i in 0 ..< len(Fixed64Values):
@ -416,15 +437,17 @@ suite "MinProtobuf test suite":
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
data.corruptHeader(k) data.corruptHeader(k)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isErr() == true
test "[fixed64] empty buffer test": test "[fixed64] empty buffer test":
var value: float64 var value: float64
var pb = initProtoBuffer() var pb = initProtoBuffer()
let res = pb.getField(1, value)
check: check:
pb.getField(1, value) == false res.isOk() == true
value == float64(0) res.get() == false
test "[fixed64] Repeated field test": test "[fixed64] Repeated field test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()
@ -442,9 +465,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getRepeatedField(2, fieldarr2) let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3) let r3 = pb2.getRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 1 len(fieldarr2) == 1
len(fieldarr1) == 4 len(fieldarr1) == 4
@ -474,9 +500,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getPackedRepeatedField(2, fieldarr2) let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3) let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 2 len(fieldarr2) == 2
len(fieldarr1) == 8 len(fieldarr1) == 8
@ -523,8 +552,9 @@ suite "MinProtobuf test suite":
# corrupting # corrupting
data.setLen(len(data) - 1) data.setLen(len(data) - 1)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
check: check:
pb.getField(1, value, valueLen) == false res.isErr() == true
test "[length] non-existent field test": test "[length] non-existent field test":
for i in 0 ..< len(LengthValues): for i in 0 ..< len(LengthValues):
@ -532,8 +562,10 @@ suite "MinProtobuf test suite":
var valueLen = 0 var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i]) var data = getLengthEncodedValue(LengthValues[i])
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(2, value, valueLen)
check: check:
pb.getField(2, value, valueLen) == false res.isOk() == true
res.get() == false
valueLen == 0 valueLen == 0
test "[length] corrupted header test": test "[length] corrupted header test":
@ -544,15 +576,18 @@ suite "MinProtobuf test suite":
var data = getLengthEncodedValue(LengthValues[i]) var data = getLengthEncodedValue(LengthValues[i])
data.corruptHeader(k) data.corruptHeader(k)
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
check: check:
pb.getField(1, value, valueLen) == false res.isErr() == true
test "[length] empty buffer test": test "[length] empty buffer test":
var value = newSeq[byte](len(LengthValues[0])) var value = newSeq[byte](len(LengthValues[0]))
var valueLen = 0 var valueLen = 0
var pb = initProtoBuffer() var pb = initProtoBuffer()
let res = pb.getField(1, value, valueLen)
check: check:
pb.getField(1, value, valueLen) == false res.isOk() == true
res.get() == false
valueLen == 0 valueLen == 0
test "[length] buffer overflow test": test "[length] buffer overflow test":
@ -562,8 +597,10 @@ suite "MinProtobuf test suite":
var value = newString(len(LengthValues[i]) - 1) var value = newString(len(LengthValues[i]) - 1)
var valueLen = 0 var valueLen = 0
var pb = initProtoBuffer(data) var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
check: check:
pb.getField(1, value, valueLen) == false res.isOk() == true
res.get() == false
valueLen == len(LengthValues[i]) valueLen == len(LengthValues[i])
isFullZero(value) == true isFullZero(value) == true
@ -578,8 +615,10 @@ suite "MinProtobuf test suite":
var pb2 = initProtoBuffer(pb1.buffer) var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(4) var value = newString(4)
var valueLen = 0 var valueLen = 0
let res = pb2.getField(1, value, valueLen)
check: check:
pb2.getField(1, value, valueLen) == true res.isOk() == true
res.get() == true
value == "SOME" value == "SOME"
test "[length] too big message test": test "[length] too big message test":
@ -593,8 +632,9 @@ suite "MinProtobuf test suite":
var pb2 = initProtoBuffer(pb1.buffer) var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(MaxMessageSize + 1) var value = newString(MaxMessageSize + 1)
var valueLen = 0 var valueLen = 0
let res = pb2.getField(1, value, valueLen)
check: check:
pb2.getField(1, value, valueLen) == false res.isErr() == true
test "[length] Repeated field test": test "[length] Repeated field test":
var pb1 = initProtoBuffer() var pb1 = initProtoBuffer()
@ -612,9 +652,12 @@ suite "MinProtobuf test suite":
let r2 = pb2.getRepeatedField(2, fieldarr2) let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3) let r3 = pb2.getRepeatedField(3, fieldarr3)
check: check:
r1 == true r1.isOk() == true
r2 == true r2.isOk() == true
r3 == false r3.isOk() == true
r1.get() == true
r2.get() == true
r3.get() == false
len(fieldarr3) == 0 len(fieldarr3) == 0
len(fieldarr2) == 1 len(fieldarr2) == 1
len(fieldarr1) == 4 len(fieldarr1) == 4
@ -662,11 +705,16 @@ suite "MinProtobuf test suite":
var lengthValue = newString(10) var lengthValue = newString(10)
var lengthSize: int var lengthSize: int
let r1 = pb.getField(1, varintValue)
let r2 = pb.getField(2, fixed32Value)
let r3 = pb.getField(3, fixed64Value)
let r4 = pb.getField(4, lengthValue, lengthSize)
check: check:
pb.getField(1, varintValue) == true r1.isOk() == true
pb.getField(2, fixed32Value) == true r2.isOk() == true
pb.getField(3, fixed64Value) == true r3.isOk() == true
pb.getField(4, lengthValue, lengthSize) == true r4.isOk() == true
lengthValue.setLen(lengthSize) lengthValue.setLen(lengthSize)