Merge branch 'master' into gossip-one-one
This commit is contained in:
commit
8078fec0f0
15
.travis.yml
15
.travis.yml
|
@ -22,13 +22,14 @@ matrix:
|
|||
- NPROC=2
|
||||
before_install:
|
||||
- export GOPATH=$HOME/go
|
||||
- os: linux
|
||||
arch: arm64
|
||||
env:
|
||||
- NPROC=6 # Worth trying more than 2 parallel jobs: https://travis-ci.community/t/no-cache-support-on-arm64/5416/8
|
||||
# (also used to get a different cache key than the amd64 one)
|
||||
before_install:
|
||||
- export GOPATH=$HOME/go
|
||||
## arm64 is very unreliable and slow, disabled for now
|
||||
# - os: linux
|
||||
# arch: arm64
|
||||
# env:
|
||||
# - NPROC=6 # Worth trying more than 2 parallel jobs: https://travis-ci.community/t/no-cache-support-on-arm64/5416/8
|
||||
# # (also used to get a different cache key than the amd64 one)
|
||||
# before_install:
|
||||
# - export GOPATH=$HOME/go
|
||||
- os: osx
|
||||
env:
|
||||
- NPROC=2
|
||||
|
|
|
@ -17,12 +17,15 @@ requires "nim >= 1.2.0",
|
|||
"stew >= 0.1.0"
|
||||
|
||||
proc runTest(filename: string, verify: bool = true, sign: bool = true) =
|
||||
var excstr = "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info"
|
||||
var excstr = "nim c --opt:speed -d:debug --verbosity:0 --hints:off"
|
||||
excstr.add(" --warning[CaseTransition]:off --warning[ObservableStores]:off --warning[LockLevel]:off")
|
||||
excstr.add(" -d:libp2p_pubsub_sign=" & $sign)
|
||||
excstr.add(" -d:libp2p_pubsub_verify=" & $verify)
|
||||
excstr.add(" tests/" & filename)
|
||||
exec excstr
|
||||
if verify and sign:
|
||||
# build it with TRACE and JSON logs
|
||||
exec excstr & " -d:chronicles_log_level=TRACE -d:chronicles_sinks:json" & " tests/" & filename
|
||||
# build it again, to run it with less verbose logs
|
||||
exec excstr & " -d:chronicles_log_level=INFO -r" & " tests/" & filename
|
||||
rmFile "tests/" & filename.toExe
|
||||
|
||||
proc buildSample(filename: string) =
|
||||
|
|
|
@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] =
|
|||
##
|
||||
## Returns number of bytes (octets) needed to store private key ``key``.
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
var blen = len(msg.buffer)
|
||||
if len(data) >= blen:
|
||||
|
@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] =
|
|||
##
|
||||
## Returns number of bytes (octets) needed to store public key ``key``.
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
var blen = len(msg.buffer)
|
||||
if len(data) >= blen and blen > 0:
|
||||
|
@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
|
|||
## Return private key ``key`` in binary form (using libp2p's protobuf
|
||||
## serialization).
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
ok(msg.buffer)
|
||||
|
||||
|
@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
|
|||
## Return public key ``key`` in binary form (using libp2p's protobuf
|
||||
## serialization).
|
||||
var msg = initProtoBuffer()
|
||||
msg.write(initProtoField(1, cast[uint64](key.scheme)))
|
||||
msg.write(initProtoField(2, ? key.getRawBytes()))
|
||||
msg.write(1, uint64(key.scheme))
|
||||
msg.write(2, ? key.getRawBytes())
|
||||
msg.finish()
|
||||
ok(msg.buffer)
|
||||
|
||||
|
@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
|
|||
var buffer: seq[byte]
|
||||
if len(data) > 0:
|
||||
var pb = initProtoBuffer(@data)
|
||||
if pb.getVarintValue(1, id) != 0:
|
||||
if pb.getBytes(2, buffer) != 0:
|
||||
if cast[int8](id) in SupportedSchemesInt:
|
||||
var scheme = cast[PKScheme](cast[int8](id))
|
||||
when key is PrivateKey:
|
||||
var nkey = PrivateKey(scheme: scheme)
|
||||
else:
|
||||
var nkey = PublicKey(scheme: scheme)
|
||||
case scheme:
|
||||
of PKScheme.RSA:
|
||||
if init(nkey.rsakey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.Ed25519:
|
||||
if init(nkey.edkey, buffer):
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.ECDSA:
|
||||
if init(nkey.eckey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.Secp256k1:
|
||||
if init(nkey.skkey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
else:
|
||||
return false
|
||||
if pb.getField(1, id) and pb.getField(2, buffer):
|
||||
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
|
||||
var scheme = cast[PKScheme](cast[int8](id))
|
||||
when key is PrivateKey:
|
||||
var nkey = PrivateKey(scheme: scheme)
|
||||
else:
|
||||
var nkey = PublicKey(scheme: scheme)
|
||||
case scheme:
|
||||
of PKScheme.RSA:
|
||||
if init(nkey.rsakey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.Ed25519:
|
||||
if init(nkey.edkey, buffer):
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.ECDSA:
|
||||
if init(nkey.eckey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
of PKScheme.Secp256k1:
|
||||
if init(nkey.skkey, buffer).isOk:
|
||||
key = nkey
|
||||
return true
|
||||
else:
|
||||
return false
|
||||
|
||||
proc init*(sig: var Signature, data: openarray[byte]): bool =
|
||||
## Initialize signature ``sig`` from raw binary form.
|
||||
|
@ -374,6 +373,24 @@ proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] =
|
|||
except ValueError:
|
||||
err(KeyError)
|
||||
|
||||
proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey =
|
||||
PrivateKey(scheme: RSA, rsakey: key)
|
||||
proc init*(t: typedesc[PrivateKey], key: EdPrivateKey): PrivateKey =
|
||||
PrivateKey(scheme: Ed25519, edkey: key)
|
||||
proc init*(t: typedesc[PrivateKey], key: SkPrivateKey): PrivateKey =
|
||||
PrivateKey(scheme: Secp256k1, skkey: key)
|
||||
proc init*(t: typedesc[PrivateKey], key: ecnist.EcPrivateKey): PrivateKey =
|
||||
PrivateKey(scheme: ECDSA, eckey: key)
|
||||
|
||||
proc init*(t: typedesc[PublicKey], key: rsa.RsaPublicKey): PublicKey =
|
||||
PublicKey(scheme: RSA, rsakey: key)
|
||||
proc init*(t: typedesc[PublicKey], key: EdPublicKey): PublicKey =
|
||||
PublicKey(scheme: Ed25519, edkey: key)
|
||||
proc init*(t: typedesc[PublicKey], key: SkPublicKey): PublicKey =
|
||||
PublicKey(scheme: Secp256k1, skkey: key)
|
||||
proc init*(t: typedesc[PublicKey], key: ecnist.EcPublicKey): PublicKey =
|
||||
PublicKey(scheme: ECDSA, eckey: key)
|
||||
|
||||
proc init*(t: typedesc[PublicKey], data: string): CryptoResult[PublicKey] =
|
||||
## Create new public key from libp2p's protobuf serialized hexadecimal string
|
||||
## form.
|
||||
|
@ -709,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte],
|
|||
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
|
||||
## comma-delimeted list of supported hashes ``hashes``.
|
||||
var msg = initProtoBuffer({WithUint32BeLength})
|
||||
msg.write(initProtoField(1, nonce))
|
||||
msg.write(initProtoField(2, pubkey))
|
||||
msg.write(initProtoField(3, exchanges))
|
||||
msg.write(initProtoField(4, ciphers))
|
||||
msg.write(initProtoField(5, hashes))
|
||||
msg.write(1, nonce)
|
||||
msg.write(2, pubkey)
|
||||
msg.write(3, exchanges)
|
||||
msg.write(4, ciphers)
|
||||
msg.write(5, hashes)
|
||||
msg.finish()
|
||||
shallowCopy(result, msg.buffer)
|
||||
|
||||
|
@ -726,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
if pb.getLengthValue(1, nonce) != -1 and
|
||||
pb.getLengthValue(2, pubkey) != -1 and
|
||||
pb.getLengthValue(3, exchanges) != -1 and
|
||||
pb.getLengthValue(4, ciphers) != -1 and
|
||||
pb.getLengthValue(5, hashes) != -1:
|
||||
result = true
|
||||
pb.getField(1, nonce) and pb.getField(2, pubkey) and
|
||||
pb.getField(3, exchanges) and pb.getField(4, ciphers) and
|
||||
pb.getField(5, hashes)
|
||||
|
||||
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
|
||||
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
|
||||
## signature of proposal blocks ``signature``.
|
||||
var msg = initProtoBuffer({WithUint32BeLength})
|
||||
msg.write(initProtoField(1, epubkey))
|
||||
msg.write(initProtoField(2, signature))
|
||||
msg.write(1, epubkey)
|
||||
msg.write(2, signature)
|
||||
msg.finish()
|
||||
shallowCopy(result, msg.buffer)
|
||||
|
||||
|
@ -749,9 +763,7 @@ proc decodeExchange*(message: seq[byte],
|
|||
##
|
||||
## Procedure returns ``true`` on success and ``false`` on error.
|
||||
var pb = initProtoBuffer(message)
|
||||
if pb.getLengthValue(1, pubkey) != -1 and
|
||||
pb.getLengthValue(2, signature) != -1:
|
||||
result = true
|
||||
pb.getField(1, pubkey) and pb.getField(2, signature)
|
||||
|
||||
## Serialization/Deserialization helpers
|
||||
|
||||
|
@ -770,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {.
|
|||
## Write Signature value ``sig`` to buffer ``vb``.
|
||||
vb.writeSeq(sig.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {.
|
||||
raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PublicKey ``pubkey``.
|
||||
result = initProtoField(index, pubkey.getBytes().tryGet())
|
||||
proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
|
||||
key: T) {.
|
||||
inline, raises: [Defect, ResultError[CryptoError]].} =
|
||||
write(pb, field, key.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {.
|
||||
raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PrivateKey ``seckey``.
|
||||
result = initProtoField(index, seckey.getBytes().tryGet())
|
||||
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
|
||||
inline, raises: [Defect, ResultError[CryptoError]].} =
|
||||
write(pb, field, sig.getBytes())
|
||||
|
||||
proc initProtoField*(index: int, sig: Signature): ProtoField =
|
||||
proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {.
|
||||
deprecated, raises: [Defect, ResultError[CryptoError]].} =
|
||||
## Initialize ProtoField with PublicKey/PrivateKey ``key``.
|
||||
result = initProtoField(index, key.getBytes().tryGet())
|
||||
|
||||
proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with Signature ``sig``.
|
||||
result = initProtoField(index, sig.getBytes())
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
|
||||
## Read ``PublicKey`` from ProtoBuf's message and validate it.
|
||||
proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int,
|
||||
value: var T): int {.deprecated.} =
|
||||
## Read PublicKey/PrivateKey from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var key: PublicKey
|
||||
result = getLengthValue(data, field, buf)
|
||||
|
@ -795,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
|
|||
else:
|
||||
value = key
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int =
|
||||
## Read ``PrivateKey`` from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var key: PrivateKey
|
||||
result = getLengthValue(data, field, buf)
|
||||
if result > 0:
|
||||
if not key.init(buf):
|
||||
result = -1
|
||||
else:
|
||||
value = key
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
|
||||
deprecated.} =
|
||||
## Read ``Signature`` from ProtoBuf's message and validate it.
|
||||
var buf: seq[byte]
|
||||
var sig: Signature
|
||||
|
@ -816,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
|
|||
result = -1
|
||||
else:
|
||||
value = sig
|
||||
|
||||
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
|
||||
value: var T): bool =
|
||||
var buffer: seq[byte]
|
||||
var key: T
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if key.init(buffer):
|
||||
value = key
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
|
||||
var buffer: seq[byte]
|
||||
var sig: Signature
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if sig.init(buffer):
|
||||
value = sig
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
|
|
@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key)
|
|||
proc public*(private: Curve25519Key): Curve25519Key =
|
||||
Curve25519.mulgen(result, private)
|
||||
|
||||
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] =
|
||||
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key =
|
||||
var res: Curve25519Key
|
||||
let defaultBrEc = brEcGetDefault()
|
||||
if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize:
|
||||
err(Curver25519GenError)
|
||||
else:
|
||||
ok(res)
|
||||
let len = brEcKeygen(
|
||||
addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
|
||||
# Per bearssl documentation, the keygen only fails if the curve is
|
||||
# unrecognised -
|
||||
doAssert len == Curve25519KeySize, "Could not generate curve"
|
||||
|
||||
res
|
||||
|
|
|
@ -39,11 +39,11 @@ const
|
|||
|
||||
type
|
||||
EcPrivateKey* = ref object
|
||||
buffer*: seq[byte]
|
||||
buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte]
|
||||
key*: BrEcPrivateKey
|
||||
|
||||
EcPublicKey* = ref object
|
||||
buffer*: seq[byte]
|
||||
buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte]
|
||||
key*: BrEcPublicKey
|
||||
|
||||
EcKeyPair* = object
|
||||
|
@ -237,7 +237,6 @@ proc random*(
|
|||
## secp521r1).
|
||||
var ecimp = brEcGetDefault()
|
||||
var res = new EcPrivateKey
|
||||
res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE)
|
||||
if brEcKeygen(addr rng.vtable, ecimp,
|
||||
addr res.key, addr res.buffer[0],
|
||||
cast[cint](kind)) == 0:
|
||||
|
@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
|
|||
if seckey.key.curve in EcSupportedCurvesCint:
|
||||
var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
|
||||
var res = new EcPublicKey
|
||||
res.buffer = newSeq[byte](length)
|
||||
if brEcComputePublicKey(ecimp, addr res.key,
|
||||
addr res.buffer[0], unsafeAddr seckey.key) == 0:
|
||||
err(EcKeyIncorrectError)
|
||||
|
@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error
|
|||
|
||||
if checkScalar(raw.toOpenArray(), curve) == 1'u32:
|
||||
key = new EcPrivateKey
|
||||
key.buffer = newSeq[byte](raw.length)
|
||||
copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
||||
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
||||
key.key.xlen = raw.length
|
||||
|
@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err
|
|||
|
||||
if checkPublic(raw.toOpenArray(), curve) != 0:
|
||||
pubkey = new EcPublicKey
|
||||
pubkey.buffer = newSeq[byte](raw.length)
|
||||
copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length)
|
||||
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
||||
pubkey.key.qlen = raw.length
|
||||
|
@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool =
|
|||
if checkScalar(data, curve) == 1'u32:
|
||||
let length = len(data)
|
||||
key = new EcPrivateKey
|
||||
key.buffer = newSeq[byte](length)
|
||||
copyMem(addr key.buffer[0], unsafeAddr data[0], length)
|
||||
key.key.x = cast[ptr cuchar](addr key.buffer[0])
|
||||
key.key.xlen = length
|
||||
|
@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool =
|
|||
if checkPublic(data, curve) != 0:
|
||||
let length = len(data)
|
||||
pubkey = new EcPublicKey
|
||||
pubkey.buffer = newSeq[byte](length)
|
||||
copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length)
|
||||
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
|
||||
pubkey.key.qlen = length
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
import nativesockets
|
||||
import tables, strutils, stew/shims/net
|
||||
import chronos
|
||||
import multicodec, multihash, multibase, transcoder, vbuffer, peerid
|
||||
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
|
||||
protobuf/minprotobuf
|
||||
import stew/[base58, base32, endians2, results]
|
||||
export results
|
||||
export results, minprotobuf, vbuffer
|
||||
|
||||
type
|
||||
MAKind* = enum
|
||||
|
@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
|
|||
else:
|
||||
ok($(proto.mcodec))
|
||||
|
||||
proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] =
|
||||
proc protoArgument*(ma: MultiAddress,
|
||||
value: var openarray[byte]): MaResult[int] =
|
||||
## Returns MultiAddress ``ma`` protocol argument value.
|
||||
##
|
||||
## If current MultiAddress do not have argument value, then result will be
|
||||
|
@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int]
|
|||
var res: int
|
||||
if proto.kind == Fixed:
|
||||
res = proto.size
|
||||
if len(value) >= res and
|
||||
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
|
||||
if len(value) >= res and
|
||||
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
|
||||
err("multiaddress: Decoding protocol error")
|
||||
else:
|
||||
ok(res)
|
||||
|
@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
|
|||
|
||||
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
|
||||
if proto.kind == None:
|
||||
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'")
|
||||
yield err(MaResult[MultiAddress], "Unsupported protocol '" &
|
||||
$header & "'")
|
||||
|
||||
elif proto.kind == Fixed:
|
||||
data.setLen(proto.size)
|
||||
|
@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
|
|||
return ok(true)
|
||||
ok(false)
|
||||
|
||||
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
|
||||
proc `[]`*(ma: MultiAddress,
|
||||
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
|
||||
## Returns partial MultiAddress with MultiCodec ``codec`` and present in
|
||||
## MultiAddress ``ma``.
|
||||
for item in ma.items:
|
||||
|
@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] =
|
|||
return err("multiaddress: Unsupported protocol '" & $header & "'")
|
||||
if proto.kind in {Fixed, Length, Path}:
|
||||
if isNil(proto.coder.bufferToString):
|
||||
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder")
|
||||
return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
|
||||
"' coder")
|
||||
if not proto.coder.bufferToString(vb.data, part):
|
||||
return err("multiaddress: Decoding protocol error")
|
||||
parts.add($(proto.mcodec))
|
||||
|
@ -729,12 +734,14 @@ proc init*(
|
|||
of None:
|
||||
raiseAssert "None checked above"
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} =
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
|
||||
value: PeerID): MaResult[MultiAddress] {.inline.} =
|
||||
## Initialize MultiAddress object from protocol id ``protocol`` and peer id
|
||||
## ``value``.
|
||||
init(mtype, protocol, value.data)
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
|
||||
value: int): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress object from protocol id ``protocol`` and integer
|
||||
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
|
||||
## ``dccp`` and ``sctp`` MultiAddresses.
|
||||
|
@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
|
|||
if mc != InvalidMultiCodec:
|
||||
result = CodeAddresses.getOrDefault(mc)
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress],
|
||||
value: string): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress object from string representation ``value``.
|
||||
var parts = value.trimRight('/').split('/')
|
||||
if len(parts[0]) != 0:
|
||||
|
@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
else:
|
||||
if proto.kind in {Fixed, Length, Path}:
|
||||
if isNil(proto.coder.stringToBuffer):
|
||||
return err("multiaddress: Missing protocol '" & part & "' transcoder")
|
||||
return err("multiaddress: Missing protocol '" &
|
||||
part & "' transcoder")
|
||||
|
||||
if offset + 1 >= len(parts):
|
||||
return err("multiaddress: Missing protocol '" & part & "' argument")
|
||||
|
@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
res.data.write(proto.mcodec)
|
||||
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
|
||||
if not res:
|
||||
return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`")
|
||||
return err("multiaddress: Error encoding `" & part & "/" &
|
||||
parts[offset + 1] & "`")
|
||||
offset += 2
|
||||
|
||||
elif proto.kind == Path:
|
||||
var path = "/" & (parts[(offset + 1)..^1].join("/"))
|
||||
res.data.write(proto.mcodec)
|
||||
if not proto.coder.stringToBuffer(path, res.data):
|
||||
return err("multiaddress: Error encoding `" & part & "/" & path & "`")
|
||||
return err("multiaddress: Error encoding `" & part & "/" &
|
||||
path & "`")
|
||||
|
||||
break
|
||||
elif proto.kind == Marker:
|
||||
|
@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
|
|||
res.data.finish()
|
||||
ok(res)
|
||||
|
||||
|
||||
proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] =
|
||||
proc init*(mtype: typedesc[MultiAddress],
|
||||
data: openarray[byte]): MaResult[MultiAddress] =
|
||||
## Initialize MultiAddress with array of bytes ``data``.
|
||||
if len(data) == 0:
|
||||
err("multiaddress: Address could not be empty!")
|
||||
|
@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress],
|
|||
var data = initVBuffer()
|
||||
data.write(familyProto.mcodec)
|
||||
var written = familyProto.coder.stringToBuffer($address, data)
|
||||
doAssert written, "Merely writing a string to a buffer should always be possible"
|
||||
doAssert written,
|
||||
"Merely writing a string to a buffer should always be possible"
|
||||
data.write(protoProto.mcodec)
|
||||
written = protoProto.coder.stringToBuffer($port, data)
|
||||
doAssert written, "Merely writing a string to a buffer should always be possible"
|
||||
doAssert written,
|
||||
"Merely writing a string to a buffer should always be possible"
|
||||
data.finish()
|
||||
|
||||
MultiAddress(data: data)
|
||||
|
@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
|
|||
else:
|
||||
ok()
|
||||
|
||||
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} =
|
||||
proc `&`*(m1, m2: MultiAddress): MultiAddress {.
|
||||
raises: [Defect, ResultError[string]].} =
|
||||
## Concatenates two addresses ``m1`` and ``m2``, and returns result.
|
||||
##
|
||||
## This procedure performs validation of concatenated result and can raise
|
||||
## exception on error.
|
||||
concat(m1, m2).tryGet()
|
||||
|
||||
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} =
|
||||
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
|
||||
raises: [Defect, ResultError[string]].} =
|
||||
## Concatenates two addresses ``m1`` and ``m2``.
|
||||
##
|
||||
## This procedure performs validation of concatenated result and can raise
|
||||
|
@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string =
|
|||
result = "(" & sub.join("|") & ")"
|
||||
elif pat.operator == Eq:
|
||||
result = $pat.value
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
|
||||
write(pb, field, value.data.buffer)
|
||||
|
||||
proc getField*(pb: var ProtoBuffer, field: int,
|
||||
value: var MultiAddress): bool {.inline.} =
|
||||
var buffer: seq[byte]
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
let ma = MultiAddress.init(buffer)
|
||||
if ma.isOk():
|
||||
value = ma.get()
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
|
||||
value: var seq[MultiAddress]): bool {.inline.} =
|
||||
var items: seq[seq[byte]]
|
||||
value.setLen(0)
|
||||
if not(getRepeatedField(pb, field, items)):
|
||||
return false
|
||||
if len(items) == 0:
|
||||
return true
|
||||
for item in items:
|
||||
let ma = MultiAddress.init(item)
|
||||
if ma.isOk():
|
||||
value.add(ma.get())
|
||||
else:
|
||||
value.setLen(0)
|
||||
return false
|
||||
|
|
|
@ -11,7 +11,6 @@ import strutils, tables
|
|||
import chronos, chronicles, stew/byteutils
|
||||
import stream/connection,
|
||||
vbuffer,
|
||||
errors,
|
||||
protocols/protocol
|
||||
|
||||
logScope:
|
||||
|
|
|
@ -14,8 +14,6 @@ import types,
|
|||
nimcrypto/utils,
|
||||
../../stream/connection,
|
||||
../../stream/bufferstream,
|
||||
../../utility,
|
||||
../../errors,
|
||||
../../peerinfo
|
||||
|
||||
export connection
|
||||
|
@ -189,14 +187,11 @@ proc closeRemote*(s: LPChannel) {.async.} =
|
|||
# stack = getStackTrace()
|
||||
|
||||
trace "got EOF, closing channel"
|
||||
|
||||
# wait for all data in the buffer to be consumed
|
||||
while s.len > 0:
|
||||
await s.dataReadEvent.wait()
|
||||
s.dataReadEvent.clear()
|
||||
await s.drainBuffer()
|
||||
|
||||
s.isEof = true # set EOF immediately to prevent further reads
|
||||
await s.close() # close local end
|
||||
|
||||
# call to avoid leaks
|
||||
await procCall BufferStream(s).close() # close parent bufferstream
|
||||
trace "channel closed on EOF"
|
||||
|
@ -227,7 +222,11 @@ method reset*(s: LPChannel) {.base, async, gcsafe.} =
|
|||
# might be dead already - reset is always
|
||||
# optimistic
|
||||
asyncCheck s.resetMessage()
|
||||
|
||||
# drain the buffer before closing
|
||||
await s.drainBuffer()
|
||||
await procCall BufferStream(s).close()
|
||||
|
||||
s.isEof = true
|
||||
s.closedLocal = true
|
||||
|
||||
|
@ -254,11 +253,11 @@ method close*(s: LPChannel) {.async, gcsafe.} =
|
|||
if s.atEof: # already closed by remote close parent buffer immediately
|
||||
await procCall BufferStream(s).close()
|
||||
except CancelledError as exc:
|
||||
await s.reset() # reset on timeout
|
||||
await s.reset()
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "exception closing channel"
|
||||
await s.reset() # reset on timeout
|
||||
await s.reset()
|
||||
|
||||
trace "lpchannel closed local"
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ import ../muxer,
|
|||
../../stream/connection,
|
||||
../../stream/bufferstream,
|
||||
../../utility,
|
||||
../../errors,
|
||||
../../peerinfo,
|
||||
coder,
|
||||
types,
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
import chronos, chronicles
|
||||
import ../protocols/protocol,
|
||||
../stream/connection,
|
||||
../peerinfo,
|
||||
../errors
|
||||
|
||||
logScope:
|
||||
|
|
|
@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} =
|
|||
## Write PeerID value ``peerid`` to buffer ``vb``.
|
||||
vb.writeSeq(pid.data)
|
||||
|
||||
proc initProtoField*(index: int, pid: PeerID): ProtoField =
|
||||
proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with PeerID ``value``.
|
||||
result = initProtoField(index, pid.data)
|
||||
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
|
||||
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {.
|
||||
deprecated.} =
|
||||
## Read ``PeerID`` from ProtoBuf's message and validate it.
|
||||
var pid: PeerID
|
||||
result = getLengthValue(data, field, pid.data)
|
||||
|
@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
|
|||
result = -1
|
||||
else:
|
||||
value = pid
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
|
||||
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
|
||||
write(pb, field, pid.data)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
|
||||
## Read ``PeerID`` from ProtoBuf's message and validate it
|
||||
var buffer: seq[byte]
|
||||
var peerId: PeerID
|
||||
if not(getField(pb, field, buffer)):
|
||||
return false
|
||||
if len(buffer) == 0:
|
||||
return false
|
||||
if peerId.init(buffer):
|
||||
pid = peerId
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
{.push raises: [Defect].}
|
||||
|
||||
import ../varint
|
||||
import ../varint, stew/endians2
|
||||
|
||||
const
|
||||
MaxMessageSize* = 1'u shl 22
|
||||
|
@ -32,10 +32,14 @@ type
|
|||
offset*: int
|
||||
length*: int
|
||||
|
||||
ProtoHeader* = object
|
||||
wire*: ProtoFieldKind
|
||||
index*: uint64
|
||||
|
||||
ProtoField* = object
|
||||
## Protobuf's message field representation object
|
||||
index: int
|
||||
case kind: ProtoFieldKind
|
||||
index*: int
|
||||
case kind*: ProtoFieldKind
|
||||
of Varint:
|
||||
vint*: uint64
|
||||
of Fixed64:
|
||||
|
@ -47,13 +51,35 @@ type
|
|||
of StartGroup, EndGroup:
|
||||
discard
|
||||
|
||||
template protoHeader*(index: int, wire: ProtoFieldKind): uint =
|
||||
## Get protobuf's field header integer for ``index`` and ``wire``.
|
||||
((uint(index) shl 3) or cast[uint](wire))
|
||||
ProtoResult {.pure.} = enum
|
||||
VarintDecodeError,
|
||||
MessageIncompleteError,
|
||||
BufferOverflowError,
|
||||
MessageSizeTooBigError,
|
||||
NoError
|
||||
|
||||
template protoHeader*(field: ProtoField): uint =
|
||||
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
|
||||
hint | hint32 | hint64 | float32 | float64
|
||||
|
||||
const
|
||||
SupportedWireTypes* = {
|
||||
int(ProtoFieldKind.Varint),
|
||||
int(ProtoFieldKind.Fixed64),
|
||||
int(ProtoFieldKind.Length),
|
||||
int(ProtoFieldKind.Fixed32)
|
||||
}
|
||||
|
||||
template checkFieldNumber*(i: int) =
|
||||
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
|
||||
"Incorrect or reserved field number")
|
||||
|
||||
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
|
||||
## Get protobuf's field header integer for ``index`` and ``wire``.
|
||||
((uint64(index) shl 3) or uint64(wire))
|
||||
|
||||
template getProtoHeader*(field: ProtoField): uint64 =
|
||||
## Get protobuf's field header integer for ``field``.
|
||||
((uint(field.index) shl 3) or cast[uint](field.kind))
|
||||
((uint64(field.index) shl 3) or uint64(field.kind))
|
||||
|
||||
template toOpenArray*(pb: ProtoBuffer): untyped =
|
||||
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
|
||||
|
@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int =
|
|||
|
||||
proc vsizeof*(field: ProtoField): int {.inline.} =
|
||||
## Returns number of bytes required to store protobuf's field ``field``.
|
||||
result = vsizeof(protoHeader(field))
|
||||
case field.kind
|
||||
of ProtoFieldKind.Varint:
|
||||
result += vsizeof(field.vint)
|
||||
vsizeof(getProtoHeader(field)) + vsizeof(field.vint)
|
||||
of ProtoFieldKind.Fixed64:
|
||||
result += sizeof(field.vfloat64)
|
||||
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64)
|
||||
of ProtoFieldKind.Fixed32:
|
||||
result += sizeof(field.vfloat32)
|
||||
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
|
||||
of ProtoFieldKind.Length:
|
||||
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer)
|
||||
vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
|
||||
len(field.vbuffer)
|
||||
else:
|
||||
discard
|
||||
0
|
||||
|
||||
proc initProtoField*(index: int, value: SomeVarint): ProtoField =
|
||||
proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with integer value.
|
||||
result = ProtoField(kind: Varint, index: index)
|
||||
when type(value) is uint64:
|
||||
|
@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField =
|
|||
else:
|
||||
result.vint = cast[uint64](value)
|
||||
|
||||
proc initProtoField*(index: int, value: bool): ProtoField =
|
||||
proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with integer value.
|
||||
result = ProtoField(kind: Varint, index: index)
|
||||
result.vint = byte(value)
|
||||
|
||||
proc initProtoField*(index: int, value: openarray[byte]): ProtoField =
|
||||
proc initProtoField*(index: int,
|
||||
value: openarray[byte]): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with bytes array.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: string): ProtoField =
|
||||
proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} =
|
||||
## Initialize ProtoField with string.
|
||||
result = ProtoField(kind: Length, index: index)
|
||||
if len(value) > 0:
|
||||
result.vbuffer = newSeq[byte](len(value))
|
||||
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
|
||||
|
||||
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} =
|
||||
proc initProtoField*(index: int,
|
||||
value: ProtoBuffer): ProtoField {.deprecated, inline.} =
|
||||
## Initialize ProtoField with nested message stored in ``value``.
|
||||
##
|
||||
## Note: This procedure performs shallow copy of ``value`` sequence.
|
||||
|
@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0,
|
|||
result.offset = offset
|
||||
result.options = options
|
||||
|
||||
proc initProtoBuffer*(data: openarray[byte], offset = 0,
|
||||
options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with copy of ``data``.
|
||||
result.buffer = @data
|
||||
result.offset = offset
|
||||
result.options = options
|
||||
|
||||
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
||||
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
|
||||
result.buffer = newSeqOfCap[byte](128)
|
||||
|
@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
|
|||
result.offset = 10
|
||||
elif {WithUint32LeLength, WithUint32BeLength} * options != {}:
|
||||
# Our buffer will start from position 4, so we can store length of buffer
|
||||
# in [0, 9].
|
||||
# in [0, 3].
|
||||
result.buffer.setLen(4)
|
||||
result.offset = 4
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: ProtoField) =
|
||||
proc write*[T: ProtoScalar](pb: var ProtoBuffer,
|
||||
field: int, value: T) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
|
||||
vsizeof(value)
|
||||
let header = ProtoFieldKind.Varint
|
||||
elif T is float32:
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
|
||||
sizeof(T)
|
||||
let header = ProtoFieldKind.Fixed32
|
||||
elif T is float64:
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
|
||||
sizeof(T)
|
||||
let header = ProtoFieldKind.Fixed64
|
||||
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
getProtoHeader(field, header))
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
let vres = putSVarint(pb.toOpenArray(), length, value)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif T is float32:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u32 = cast[uint32](value)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
elif T is float64:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u64 = cast[uint64](value)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
|
||||
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
|
||||
value: openarray[T]) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
let dlength =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
var res = 0
|
||||
for item in value:
|
||||
res += vsizeof(item)
|
||||
res
|
||||
elif (T is float32) or (T is float64):
|
||||
len(value) * sizeof(T)
|
||||
|
||||
let header = getProtoHeader(field, ProtoFieldKind.Length)
|
||||
let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length, header)
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
length = 0
|
||||
let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength))
|
||||
doAssert(lres.isOk())
|
||||
pb.offset += length
|
||||
for item in value:
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
length = 0
|
||||
let vres = PB.putUVarint(pb.toOpenArray(), length, item)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
length = 0
|
||||
let vres = PB.putSVarint(pb.toOpenArray(), length, item)
|
||||
doAssert(vres.isOk())
|
||||
pb.offset += length
|
||||
elif T is float32:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u32 = cast[uint32](item)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
elif T is float64:
|
||||
doAssert(pb.isEnough(sizeof(T)))
|
||||
let u64 = cast[uint64](item)
|
||||
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
|
||||
pb.offset += sizeof(T)
|
||||
|
||||
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
|
||||
value: openarray[T]) =
|
||||
checkFieldNumber(field)
|
||||
var length = 0
|
||||
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
|
||||
vsizeof(uint64(len(value))) + len(value)
|
||||
pb.buffer.setLen(len(pb.buffer) + flength)
|
||||
let hres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
getProtoHeader(field, ProtoFieldKind.Length))
|
||||
doAssert(hres.isOk())
|
||||
pb.offset += length
|
||||
let lres = PB.putUVarint(pb.toOpenArray(), length,
|
||||
uint64(len(value)))
|
||||
doAssert(lres.isOk())
|
||||
pb.offset += length
|
||||
if len(value) > 0:
|
||||
doAssert(pb.isEnough(len(value)))
|
||||
copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
|
||||
pb.offset += len(value)
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} =
|
||||
## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer
|
||||
## ``pb`` with field number ``field``.
|
||||
write(pb, field, value.buffer)
|
||||
|
||||
proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} =
|
||||
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
|
||||
var length = 0
|
||||
var res: VarintResult[void]
|
||||
pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
|
||||
res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field))
|
||||
res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field))
|
||||
doAssert(res.isOk())
|
||||
pb.offset += length
|
||||
case field.kind
|
||||
|
@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) =
|
|||
pb.offset = pos
|
||||
elif WithUint32BeLength in pb.options:
|
||||
let size = uint(len(pb.buffer) - 4)
|
||||
pb.buffer[0] = byte((size shr 24) and 0xFF'u)
|
||||
pb.buffer[1] = byte((size shr 16) and 0xFF'u)
|
||||
pb.buffer[2] = byte((size shr 8) and 0xFF'u)
|
||||
pb.buffer[3] = byte(size and 0xFF'u)
|
||||
pb.buffer[0 ..< 4] = toBytesBE(uint32(size))
|
||||
pb.offset = 4
|
||||
elif WithUint32LeLength in pb.options:
|
||||
let size = uint(len(pb.buffer) - 4)
|
||||
pb.buffer[0] = byte(size and 0xFF'u)
|
||||
pb.buffer[1] = byte((size shr 8) and 0xFF'u)
|
||||
pb.buffer[2] = byte((size shr 16) and 0xFF'u)
|
||||
pb.buffer[3] = byte((size shr 24) and 0xFF'u)
|
||||
pb.buffer[0 ..< 4] = toBytesLE(uint32(size))
|
||||
pb.offset = 4
|
||||
else:
|
||||
pb.offset = 0
|
||||
|
||||
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
|
||||
var length = 0
|
||||
var hdr = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, hdr).isOk():
|
||||
let index = uint64(hdr shr 3)
|
||||
let wire = hdr and 0x07
|
||||
if wire in SupportedWireTypes:
|
||||
data.offset += length
|
||||
header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire))
|
||||
true
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
|
||||
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
|
||||
case header.wire
|
||||
of ProtoFieldKind.Varint:
|
||||
var length = 0
|
||||
var value = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Fixed32:
|
||||
if data.isEnough(sizeof(uint32)):
|
||||
data.offset += sizeof(uint32)
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Fixed64:
|
||||
if data.isEnough(sizeof(uint64)):
|
||||
data.offset += sizeof(uint64)
|
||||
true
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.Length:
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
data.offset += int(bsize)
|
||||
true
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
else:
|
||||
false
|
||||
of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
|
||||
false
|
||||
|
||||
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
|
||||
header: ProtoHeader,
|
||||
outval: var T): ProtoResult =
|
||||
when (T is uint64) or (T is uint32) or (T is uint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
var length = 0
|
||||
var value = T(0)
|
||||
if PB.getUVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
elif (T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
doAssert(header.wire == ProtoFieldKind.Varint)
|
||||
var length = 0
|
||||
var value = T(0)
|
||||
if getSVarint(data.toOpenArray(), length, value).isOk():
|
||||
data.offset += length
|
||||
outval = value
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
elif T is float32:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed32)
|
||||
if data.isEnough(sizeof(float32)):
|
||||
outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
|
||||
data.offset += sizeof(float32)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
elif T is float64:
|
||||
doAssert(header.wire == ProtoFieldKind.Fixed64)
|
||||
if data.isEnough(sizeof(float64)):
|
||||
outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
|
||||
data.offset += sizeof(float64)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
|
||||
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var openarray[T],
|
||||
outLength: var int): ProtoResult =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
|
||||
outLength = 0
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
outLength = int(bsize)
|
||||
if len(outBytes) >= int(bsize):
|
||||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
# Buffer overflow should not be critical failure
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.BufferOverflowError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
|
||||
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
|
||||
outBytes: var T): ProtoResult =
|
||||
doAssert(header.wire == ProtoFieldKind.Length)
|
||||
var length = 0
|
||||
var bsize = 0'u64
|
||||
outBytes.setLen(0)
|
||||
|
||||
if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
|
||||
data.offset += length
|
||||
if bsize <= uint64(MaxMessageSize):
|
||||
if data.isEnough(int(bsize)):
|
||||
outBytes.setLen(bsize)
|
||||
if bsize > 0'u64:
|
||||
copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
|
||||
data.offset += int(bsize)
|
||||
ProtoResult.NoError
|
||||
else:
|
||||
ProtoResult.MessageIncompleteError
|
||||
else:
|
||||
ProtoResult.MessageSizeTooBigError
|
||||
else:
|
||||
ProtoResult.VarintDecodeError
|
||||
|
||||
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
checkFieldNumber(field)
|
||||
var value: T
|
||||
var res = false
|
||||
var pb = data
|
||||
output = T(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output = T(0)
|
||||
return false
|
||||
let wireCheck =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
header.wire == ProtoFieldKind.Varint
|
||||
elif T is float32:
|
||||
header.wire == ProtoFieldKind.Fixed32
|
||||
elif T is float64:
|
||||
header.wire == ProtoFieldKind.Fixed64
|
||||
if header.index == uint64(field):
|
||||
if wireCheck:
|
||||
let r = getValue(pb, header, value)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
output = value
|
||||
else:
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from what we expect, because it
|
||||
# is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output = T(0)
|
||||
return false
|
||||
res
|
||||
|
||||
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
|
||||
output: var openarray[T],
|
||||
outlen: var int): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
var res = false
|
||||
|
||||
outlen = 0
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output, outlen)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
if len(output) > 0:
|
||||
zeroMem(addr output[0], len(output))
|
||||
outlen = 0
|
||||
return false
|
||||
|
||||
res
|
||||
|
||||
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var T): bool =
|
||||
checkFieldNumber(field)
|
||||
var res = false
|
||||
var pb = data
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
let r = getValue(pb, header, output)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
res = true
|
||||
of ProtoResult.BufferOverflowError:
|
||||
# Buffer overflow error is not critical error, we still can get
|
||||
# field values with proper size.
|
||||
discard
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
# We are ignoring wire types different from ProtoFieldKind.Length,
|
||||
# because it is how `protoc` is working.
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
res
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
|
||||
inline.} =
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(field, buffer):
|
||||
output = initProtoBuffer(buffer)
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
|
||||
field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32,
|
||||
ProtoFieldKind.Fixed64}:
|
||||
var item: T
|
||||
let r = getValue(pb, header, item)
|
||||
case r
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
|
||||
output: var seq[T]): bool =
|
||||
checkFieldNumber(field)
|
||||
var pb = data
|
||||
output.setLen(0)
|
||||
|
||||
while not(pb.isEmpty()):
|
||||
var header: ProtoHeader
|
||||
if not(pb.getHeader(header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if header.index == uint64(field):
|
||||
if header.wire == ProtoFieldKind.Length:
|
||||
var arritem: seq[byte]
|
||||
let rarr = getValue(pb, header, arritem)
|
||||
case rarr
|
||||
of ProtoResult.NoError:
|
||||
var pbarr = initProtoBuffer(arritem)
|
||||
let itemHeader =
|
||||
when (T is uint64) or (T is uint32) or (T is uint) or
|
||||
(T is zint64) or (T is zint32) or (T is zint) or
|
||||
(T is hint64) or (T is hint32) or (T is hint):
|
||||
ProtoHeader(wire: ProtoFieldKind.Varint)
|
||||
elif T is float32:
|
||||
ProtoHeader(wire: ProtoFieldKind.Fixed32)
|
||||
elif T is float64:
|
||||
ProtoHeader(wire: ProtoFieldKind.Fixed64)
|
||||
while not(pbarr.isEmpty()):
|
||||
var item: T
|
||||
let res = getValue(pbarr, itemHeader, item)
|
||||
case res
|
||||
of ProtoResult.NoError:
|
||||
output.add(item)
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
else:
|
||||
if not(skipValue(pb, header)):
|
||||
output.setLen(0)
|
||||
return false
|
||||
|
||||
if len(output) > 0:
|
||||
true
|
||||
else:
|
||||
false
|
||||
|
||||
proc getVarintValue*(data: var ProtoBuffer, field: int,
|
||||
value: var SomeVarint): int =
|
||||
value: var SomeVarint): int {.deprecated.} =
|
||||
## Get value of `Varint` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
var soffset = data.offset
|
||||
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, header).isOk():
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Varint):
|
||||
if header == getProtoHeader(field, Varint):
|
||||
if not data.isEmpty():
|
||||
when type(value) is int32 or type(value) is int64 or type(value) is int:
|
||||
let res = getSVarint(data.toOpenArray(), length, value)
|
||||
|
@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int,
|
|||
data.offset = soffset
|
||||
|
||||
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
||||
buffer: var T): int =
|
||||
buffer: var T): int {.deprecated.} =
|
||||
## Get value of `Length` type.
|
||||
var length = 0
|
||||
var header = 0'u64
|
||||
|
@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
|||
var soffset = data.offset
|
||||
result = -1
|
||||
buffer.setLen(0)
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, header).isOk():
|
||||
data.offset += length
|
||||
if header == protoHeader(field, Length):
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk():
|
||||
if header == getProtoHeader(field, Length):
|
||||
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
|
||||
length, ssize).isOk():
|
||||
data.offset += length
|
||||
if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
|
||||
buffer.setLen(ssize)
|
||||
|
@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
|
|||
data.offset = soffset
|
||||
|
||||
proc getBytes*(data: var ProtoBuffer, field: int,
|
||||
buffer: var seq[byte]): int {.inline.} =
|
||||
buffer: var seq[byte]): int {.deprecated, inline.} =
|
||||
## Get value of `Length` type as bytes.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc getString*(data: var ProtoBuffer, field: int,
|
||||
buffer: var string): int {.inline.} =
|
||||
buffer: var string): int {.deprecated, inline.} =
|
||||
## Get value of `Length` type as string.
|
||||
result = getLengthValue(data, field, buffer)
|
||||
|
||||
proc enterSubmessage*(pb: var ProtoBuffer): int =
|
||||
proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} =
|
||||
## Processes protobuf's sub-message and adjust internal offset to enter
|
||||
## inside of sub-message. Returns field index of sub-message field or
|
||||
## ``0`` on error.
|
||||
|
@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
|
|||
var msize = 0'u64
|
||||
var soffset = pb.offset
|
||||
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk():
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
|
||||
length, header).isOk():
|
||||
pb.offset += length
|
||||
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk():
|
||||
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
|
||||
length, msize).isOk():
|
||||
pb.offset += length
|
||||
if msize <= MaxMessageSize and pb.isEnough(int(msize)):
|
||||
pb.length = int(msize)
|
||||
|
@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
|
|||
# Restore offset on error
|
||||
pb.offset = soffset
|
||||
|
||||
proc skipSubmessage*(pb: var ProtoBuffer) =
|
||||
proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} =
|
||||
## Skip current protobuf's sub-message and adjust internal offset to the
|
||||
## end of sub-message.
|
||||
doAssert(pb.length != 0)
|
||||
|
|
|
@ -47,61 +47,49 @@ type
|
|||
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
|
||||
result = initProtoBuffer()
|
||||
|
||||
result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet()))
|
||||
result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
|
||||
|
||||
for ma in peerInfo.addrs:
|
||||
result.write(initProtoField(2, ma.data.buffer))
|
||||
result.write(2, ma.data.buffer)
|
||||
|
||||
for proto in peerInfo.protocols:
|
||||
result.write(initProtoField(3, proto))
|
||||
result.write(3, proto)
|
||||
|
||||
result.write(initProtoField(4, observedAddr.data.buffer))
|
||||
result.write(4, observedAddr.data.buffer)
|
||||
|
||||
let protoVersion = ProtoVersion
|
||||
result.write(initProtoField(5, protoVersion))
|
||||
result.write(5, protoVersion)
|
||||
|
||||
let agentVersion = AgentVersion
|
||||
result.write(initProtoField(6, agentVersion))
|
||||
result.write(6, agentVersion)
|
||||
result.finish()
|
||||
|
||||
proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
|
||||
var pb = initProtoBuffer(buf)
|
||||
|
||||
result.pubKey = none(PublicKey)
|
||||
var pubKey: PublicKey
|
||||
if pb.getValue(1, pubKey) > 0:
|
||||
if pb.getField(1, pubKey):
|
||||
trace "read public key from message", pubKey = ($pubKey).shortLog
|
||||
result.pubKey = some(pubKey)
|
||||
|
||||
result.addrs = newSeq[MultiAddress]()
|
||||
var address = newSeq[byte]()
|
||||
while pb.getBytes(2, address) > 0:
|
||||
if len(address) != 0:
|
||||
var copyaddr = address
|
||||
var ma = MultiAddress.init(copyaddr).tryGet()
|
||||
result.addrs.add(ma)
|
||||
trace "read address bytes from message", address = ma
|
||||
address.setLen(0)
|
||||
if pb.getRepeatedField(2, result.addrs):
|
||||
trace "read addresses from message", addresses = result.addrs
|
||||
|
||||
var proto = ""
|
||||
while pb.getString(3, proto) > 0:
|
||||
trace "read proto from message", proto = proto
|
||||
result.protos.add(proto)
|
||||
proto = ""
|
||||
if pb.getRepeatedField(3, result.protos):
|
||||
trace "read protos from message", protocols = result.protos
|
||||
|
||||
var observableAddr = newSeq[byte]()
|
||||
if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr
|
||||
var ma = MultiAddress.init(observableAddr).tryGet()
|
||||
trace "read observedAddr from message", address = ma
|
||||
result.observedAddr = some(ma)
|
||||
var observableAddr: MultiAddress
|
||||
if pb.getField(4, observableAddr):
|
||||
trace "read observableAddr from message", address = observableAddr
|
||||
result.observedAddr = some(observableAddr)
|
||||
|
||||
var protoVersion = ""
|
||||
if pb.getString(5, protoVersion) > 0:
|
||||
if pb.getField(5, protoVersion):
|
||||
trace "read protoVersion from message", protoVersion = protoVersion
|
||||
result.protoVersion = some(protoVersion)
|
||||
|
||||
var agentVersion = ""
|
||||
if pb.getString(6, agentVersion) > 0:
|
||||
if pb.getField(6, agentVersion):
|
||||
trace "read agentVersion from message", agentVersion = agentVersion
|
||||
result.agentVersion = some(agentVersion)
|
||||
|
||||
|
|
|
@ -15,9 +15,7 @@ import pubsub,
|
|||
rpc/[messages, message],
|
||||
../../stream/connection,
|
||||
../../peerid,
|
||||
../../peerinfo,
|
||||
../../utility,
|
||||
../../errors
|
||||
../../peerinfo
|
||||
|
||||
logScope:
|
||||
topics = "floodsub"
|
||||
|
@ -26,7 +24,7 @@ const FloodSubCodec* = "/floodsub/1.0.0"
|
|||
|
||||
type
|
||||
FloodSub* = ref object of PubSub
|
||||
floodsub*: Table[string, HashSet[string]] # topic to remote peer map
|
||||
floodsub*: PeerTable # topic to remote peer map
|
||||
seen*: TimedCache[string] # list of messages forwarded to peers
|
||||
|
||||
method subscribeTopic*(f: FloodSub,
|
||||
|
@ -35,23 +33,28 @@ method subscribeTopic*(f: FloodSub,
|
|||
peerId: string) {.gcsafe, async.} =
|
||||
await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId)
|
||||
|
||||
let peer = f.peers.getOrDefault(peerId)
|
||||
if peer == nil:
|
||||
debug "subscribeTopic on a nil peer!"
|
||||
return
|
||||
|
||||
if topic notin f.floodsub:
|
||||
f.floodsub[topic] = initHashSet[string]()
|
||||
f.floodsub[topic] = initHashSet[PubSubPeer]()
|
||||
|
||||
if subscribe:
|
||||
trace "adding subscription for topic", peer = peerId, name = topic
|
||||
trace "adding subscription for topic", peer = peer.id, name = topic
|
||||
# subscribe the peer to the topic
|
||||
f.floodsub[topic].incl(peerId)
|
||||
f.floodsub[topic].incl(peer)
|
||||
else:
|
||||
trace "removing subscription for topic", peer = peerId, name = topic
|
||||
trace "removing subscription for topic", peer = peer.id, name = topic
|
||||
# unsubscribe the peer from the topic
|
||||
f.floodsub[topic].excl(peerId)
|
||||
f.floodsub[topic].excl(peer)
|
||||
|
||||
method handleDisconnect*(f: FloodSub, peer: PubSubPeer) =
|
||||
## handle peer disconnects
|
||||
for t in toSeq(f.floodsub.keys):
|
||||
if t in f.floodsub:
|
||||
f.floodsub[t].excl(peer.id)
|
||||
f.floodsub[t].excl(peer)
|
||||
|
||||
procCall PubSub(f).handleDisconnect(peer)
|
||||
|
||||
|
@ -62,7 +65,7 @@ method rpcHandler*(f: FloodSub,
|
|||
|
||||
for m in rpcMsgs: # for all RPC messages
|
||||
if m.messages.len > 0: # if there are any messages
|
||||
var toSendPeers: HashSet[string] = initHashSet[string]()
|
||||
var toSendPeers = initHashSet[PubSubPeer]()
|
||||
for msg in m.messages: # for every message
|
||||
let msgId = f.msgIdProvider(msg)
|
||||
logScope: msgId
|
||||
|
@ -138,7 +141,8 @@ method publish*(f: FloodSub,
|
|||
let (published, failed) = await f.sendHelper(f.floodsub.getOrDefault(topic), @[msg])
|
||||
for p in failed:
|
||||
let peer = f.peers.getOrDefault(p)
|
||||
f.handleDisconnect(peer) # cleanup failed peers
|
||||
if not isNil(peer):
|
||||
f.handleDisconnect(peer) # cleanup failed peers
|
||||
|
||||
libp2p_pubsub_messages_published.inc(labelValues = [topic])
|
||||
|
||||
|
@ -157,6 +161,6 @@ method initPubSub*(f: FloodSub) =
|
|||
procCall PubSub(f).initPubSub()
|
||||
f.peers = initTable[string, PubSubPeer]()
|
||||
f.topics = initTable[string, Topic]()
|
||||
f.floodsub = initTable[string, HashSet[string]]()
|
||||
f.floodsub = initTable[string, HashSet[PubSubPeer]]()
|
||||
f.seen = newTimedCache[string](2.minutes)
|
||||
f.init()
|
||||
|
|
|
@ -62,11 +62,11 @@ type
|
|||
|
||||
GossipSub* = ref object of FloodSub
|
||||
parameters*: GossipSubParams
|
||||
mesh*: Table[string, HashSet[string]] # meshes - topic to peer
|
||||
fanout*: Table[string, HashSet[string]] # fanout - topic to peer
|
||||
gossipsub*: Table[string, HashSet[string]] # topic to peer map of all gossipsub peers
|
||||
explicit*: Table[string, HashSet[string]] # # topic to peer map of all explicit peers
|
||||
explicitPeers*: HashSet[string] # explicit (always connected/forward) peers
|
||||
mesh*: PeerTable # peers that we send messages to when we are subscribed to the topic
|
||||
fanout*: PeerTable # peers that we send messages to when we're not subscribed to the topic
|
||||
gossipsub*: PeerTable # peers that are subscribed to a topic
|
||||
explicit*: PeerTable # directpeers that we keep alive explicitly
|
||||
explicitPeers*: HashSet[string] # explicit (always connected/forward) peers
|
||||
lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics
|
||||
gossip*: Table[string, seq[ControlIHave]] # pending gossip
|
||||
control*: Table[string, ControlMessage] # pending control messages
|
||||
|
@ -97,6 +97,28 @@ proc init*(_: type[GossipSubParams]): GossipSubParams =
|
|||
publishThreshold: 1.0,
|
||||
)
|
||||
|
||||
func addPeer(table: var PeerTable, topic: string, peer: PubSubPeer): bool =
|
||||
# returns true if the peer was added, false if it was already in the collection
|
||||
not table.mgetOrPut(topic, initHashSet[PubSubPeer]()).containsOrIncl(peer)
|
||||
|
||||
func removePeer(table: var PeerTable, topic: string, peer: PubSubPeer) =
|
||||
table.withValue(topic, peers):
|
||||
peers[].excl(peer)
|
||||
if peers[].len == 0:
|
||||
table.del(topic)
|
||||
|
||||
func hasPeer(table: PeerTable, topic: string, peer: PubSubPeer): bool =
|
||||
(topic in table) and (peer in table[topic])
|
||||
|
||||
func peers(table: PeerTable, topic: string): int =
|
||||
if topic in table:
|
||||
table[topic].len
|
||||
else:
|
||||
0
|
||||
|
||||
func getPeers(table: Table[string, HashSet[string]], topic: string): HashSet[string] =
|
||||
table.getOrDefault(topic, initHashSet[string]())
|
||||
|
||||
method init*(g: GossipSub) =
|
||||
proc handler(conn: Connection, proto: string) {.async.} =
|
||||
## main protocol handler that gets triggered on every
|
||||
|
@ -116,141 +138,104 @@ method init*(g: GossipSub) =
|
|||
proc replenishFanout(g: GossipSub, topic: string) =
|
||||
## get fanout peers for a topic
|
||||
trace "about to replenish fanout"
|
||||
if topic notin g.fanout:
|
||||
g.fanout[topic] = initHashSet[string]()
|
||||
|
||||
if g.fanout.getOrDefault(topic).len < GossipSubDLo:
|
||||
trace "replenishing fanout", peers = g.fanout.getOrDefault(topic).len
|
||||
if topic in toSeq(g.gossipsub.keys):
|
||||
for p in g.gossipsub.getOrDefault(topic):
|
||||
if not g.fanout[topic].containsOrIncl(p):
|
||||
if g.fanout.getOrDefault(topic).len == GossipSubD:
|
||||
if g.fanout.peers(topic) < GossipSubDLo:
|
||||
trace "replenishing fanout", peers = g.fanout.peers(topic)
|
||||
if topic in g.gossipsub:
|
||||
for peer in g.gossipsub[topic]:
|
||||
if g.fanout.addPeer(topic, peer):
|
||||
if g.fanout.peers(topic) == GossipSubD:
|
||||
break
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout.getOrDefault(topic).len.int64,
|
||||
labelValues = [topic])
|
||||
.set(g.fanout.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
trace "fanout replenished with peers", peers = g.fanout.getOrDefault(topic).len
|
||||
|
||||
template moveToMeshHelper(g: GossipSub,
|
||||
topic: string,
|
||||
table: Table[string, HashSet[string]]) =
|
||||
## move peers from `table` into `mesh`
|
||||
##
|
||||
var peerIds = toSeq(table.getOrDefault(topic))
|
||||
|
||||
logScope:
|
||||
topic = topic
|
||||
meshPeers = g.mesh.getOrDefault(topic).len
|
||||
peers = peerIds.len
|
||||
|
||||
shuffle(peerIds)
|
||||
for id in peerIds:
|
||||
if g.mesh.getOrDefault(topic).len > GossipSubD:
|
||||
break
|
||||
|
||||
trace "gathering peers for mesh"
|
||||
if topic notin table:
|
||||
continue
|
||||
|
||||
trace "getting peers", topic,
|
||||
peers = peerIds.len
|
||||
|
||||
table[topic].excl(id) # always exclude
|
||||
if id in g.mesh[topic]:
|
||||
continue # we already have this peer in the mesh, try again
|
||||
|
||||
if id in g.peers:
|
||||
let p = g.peers[id]
|
||||
if p.connected:
|
||||
# send a graft message to the peer
|
||||
await p.sendGraft(@[topic])
|
||||
g.mesh[topic].incl(id)
|
||||
trace "got peer", peer = id
|
||||
trace "fanout replenished with peers", peers = g.fanout.peers(topic)
|
||||
|
||||
proc rebalanceMesh(g: GossipSub, topic: string) {.async.} =
|
||||
try:
|
||||
trace "about to rebalance mesh"
|
||||
# create a mesh topic that we're subscribing to
|
||||
if topic notin g.mesh:
|
||||
g.mesh[topic] = initHashSet[string]()
|
||||
trace "about to rebalance mesh"
|
||||
# create a mesh topic that we're subscribing to
|
||||
|
||||
if g.mesh.getOrDefault(topic).len < GossipSubDlo:
|
||||
trace "replenishing mesh", topic
|
||||
# replenish the mesh if we're below GossipSubDlo
|
||||
var
|
||||
grafts, prunes: seq[PubSubPeer]
|
||||
|
||||
# move fanout nodes first
|
||||
g.moveToMeshHelper(topic, g.fanout)
|
||||
if g.mesh.peers(topic) < GossipSubDlo:
|
||||
trace "replenishing mesh", topic, peers = g.mesh.peers(topic)
|
||||
# replenish the mesh if we're below GossipSubDlo
|
||||
var newPeers = toSeq(
|
||||
g.gossipsub.getOrDefault(topic, initHashSet[PubSubPeer]()) -
|
||||
g.mesh.getOrDefault(topic, initHashSet[PubSubPeer]())
|
||||
)
|
||||
|
||||
# move gossipsub nodes second
|
||||
g.moveToMeshHelper(topic, g.gossipsub)
|
||||
logScope:
|
||||
topic = topic
|
||||
meshPeers = g.mesh.peers(topic)
|
||||
newPeers = newPeers.len
|
||||
|
||||
if g.mesh.getOrDefault(topic).len > GossipSubDhi:
|
||||
# prune peers if we've gone over
|
||||
shuffle(newPeers)
|
||||
|
||||
# ATTN possible perf bottleneck here... score is a "red" function
|
||||
# and we call a lot of Table[] etc etc
|
||||
trace "getting peers", topic, peers = newPeers.len
|
||||
|
||||
# gather peers
|
||||
var peers = toSeq(g.mesh[topic])
|
||||
# sort peers by score
|
||||
peers.sort(proc (x, y: string): int =
|
||||
let
|
||||
peerx = g.peers[x].score()
|
||||
peery = g.peers[y].score()
|
||||
if peerx < peery: -1
|
||||
elif peerx == peery: 0
|
||||
else: 1)
|
||||
|
||||
while g.mesh[topic].len > GossipSubD:
|
||||
trace "pruning peers", peers = g.mesh[topic].len
|
||||
for peer in newPeers:
|
||||
# send a graft message to the peer
|
||||
grafts.add peer
|
||||
discard g.mesh.addPeer(topic, peer)
|
||||
trace "got peer", peer = peer.id
|
||||
|
||||
# pop a low score peer
|
||||
let
|
||||
id = peers.pop()
|
||||
g.mesh[topic].excl(id)
|
||||
if g.mesh.peers(topic) > GossipSubDhi:
|
||||
# prune peers if we've gone over
|
||||
# gather peers
|
||||
var mesh = toSeq(g.mesh[topic])
|
||||
# sort peers by score
|
||||
mesh.sort(proc (x, y: PubSubPeer): int =
|
||||
let
|
||||
peerx = x.score()
|
||||
peery = y.score()
|
||||
if peerx < peery: -1
|
||||
elif peerx == peery: 0
|
||||
else: 1)
|
||||
|
||||
# send a prune message to the peer
|
||||
let
|
||||
p = g.peers[id]
|
||||
# TODO send a set of other peers where the pruned peer can connect to reform its mesh
|
||||
await p.sendPrune(@[topic])
|
||||
trace "about to prune mesh", mesh = mesh.len
|
||||
for peer in mesh:
|
||||
if g.mesh.peers(topic) <= GossipSubD:
|
||||
break
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_gossipsub
|
||||
.set(g.gossipsub.getOrDefault(topic).len.int64,
|
||||
labelValues = [topic])
|
||||
trace "pruning peers", peers = g.mesh.peers(topic)
|
||||
# send a graft message to the peer
|
||||
g.mesh.removePeer(topic, peer)
|
||||
prunes.add(peer)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout.getOrDefault(topic).len.int64,
|
||||
labelValues = [topic])
|
||||
libp2p_gossipsub_peers_per_topic_gossipsub
|
||||
.set(g.gossipsub.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh.getOrDefault(topic).len.int64,
|
||||
labelValues = [topic])
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
trace "mesh balanced, got peers", peers = g.mesh.getOrDefault(topic).len,
|
||||
topicId = topic
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
warn "exception occurred re-balancing mesh", exc = exc.msg
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
proc dropFanoutPeers(g: GossipSub) {.async.} =
|
||||
# Send changes to peers after table updates to avoid stale state
|
||||
for p in grafts:
|
||||
await p.sendGraft(@[topic])
|
||||
for p in prunes:
|
||||
await p.sendPrune(@[topic])
|
||||
|
||||
trace "mesh balanced, got peers", peers = g.mesh.peers(topic),
|
||||
topicId = topic
|
||||
|
||||
proc dropFanoutPeers(g: GossipSub) =
|
||||
# drop peers that we haven't published to in
|
||||
# GossipSubFanoutTTL seconds
|
||||
var dropping = newSeq[string]()
|
||||
for topic, val in g.lastFanoutPubSub:
|
||||
if Moment.now > val:
|
||||
dropping.add(topic)
|
||||
let now = Moment.now()
|
||||
for topic in toSeq(g.lastFanoutPubSub.keys):
|
||||
let val = g.lastFanoutPubSub[topic]
|
||||
if now > val:
|
||||
g.fanout.del(topic)
|
||||
g.lastFanoutPubSub.del(topic)
|
||||
trace "dropping fanout topic", topic
|
||||
|
||||
for topic in dropping:
|
||||
g.lastFanoutPubSub.del(topic)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout.getOrDefault(topic).len.int64, labelValues = [topic])
|
||||
.set(g.fanout.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} =
|
||||
## gossip iHave messages to peers
|
||||
|
@ -278,22 +263,18 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} =
|
|||
trace "topic not in gossip array, skipping", topicID = topic
|
||||
continue
|
||||
|
||||
for id in allPeers:
|
||||
for peer in allPeers:
|
||||
if result.len >= GossipSubD:
|
||||
trace "got gossip peers", peers = result.len
|
||||
break
|
||||
|
||||
if allPeers.len == 0:
|
||||
trace "no peers for topic, skipping", topicID = topic
|
||||
break
|
||||
|
||||
if id in gossipPeers:
|
||||
if peer in gossipPeers:
|
||||
continue
|
||||
|
||||
if id notin result:
|
||||
result[id] = controlMsg
|
||||
if peer.id notin result:
|
||||
result[peer.id] = controlMsg
|
||||
|
||||
result[id].ihave.add(ihave)
|
||||
result[peer.id].ihave.add(ihave)
|
||||
|
||||
proc heartbeat(g: GossipSub) {.async.} =
|
||||
while g.heartbeatRunning:
|
||||
|
@ -303,7 +284,7 @@ proc heartbeat(g: GossipSub) {.async.} =
|
|||
for t in toSeq(g.topics.keys):
|
||||
await g.rebalanceMesh(t)
|
||||
|
||||
await g.dropFanoutPeers()
|
||||
g.dropFanoutPeers()
|
||||
|
||||
# replenish known topics to the fanout
|
||||
for t in toSeq(g.fanout.keys):
|
||||
|
@ -322,35 +303,38 @@ proc heartbeat(g: GossipSub) {.async.} =
|
|||
except CatchableError as exc:
|
||||
trace "exception ocurred in gossipsub heartbeat", exc = exc.msg
|
||||
|
||||
await sleepAsync(1.seconds)
|
||||
await sleepAsync(GossipSubHeartbeatInterval)
|
||||
|
||||
method handleDisconnect*(g: GossipSub, peer: PubSubPeer) =
|
||||
## handle peer disconnects
|
||||
procCall FloodSub(g).handleDisconnect(peer)
|
||||
|
||||
for t in toSeq(g.gossipsub.keys):
|
||||
g.gossipsub[t].excl(peer.id)
|
||||
g.gossipsub.removePeer(t, peer)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_gossipsub
|
||||
.set(g.gossipsub.getOrDefault(t).len.int64, labelValues = [t])
|
||||
.set(g.gossipsub.peers(t).int64, labelValues = [t])
|
||||
|
||||
for t in toSeq(g.mesh.keys):
|
||||
g.mesh[t].excl(peer.id)
|
||||
g.mesh.removePeer(t, peer)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh[t].len.int64, labelValues = [t])
|
||||
.set(g.mesh.peers(t).int64, labelValues = [t])
|
||||
|
||||
for t in toSeq(g.fanout.keys):
|
||||
g.fanout[t].excl(peer.id)
|
||||
g.fanout.removePeer(t, peer)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout[t].len.int64, labelValues = [t])
|
||||
.set(g.fanout.peers(t).int64, labelValues = [t])
|
||||
|
||||
if peer.peerInfo.maintain:
|
||||
g.explicitPeers.excl(peer.id)
|
||||
for t in toSeq(g.explicit.keys):
|
||||
g.explicit.removePeer(t, peer)
|
||||
|
||||
g.explicitPeers.excl(peer.id)
|
||||
|
||||
method subscribePeer*(p: GossipSub,
|
||||
conn: Connection) =
|
||||
conn: Connection) =
|
||||
procCall PubSub(p).subscribePeer(conn)
|
||||
asyncCheck p.handleConn(conn, GossipSubCodec_11)
|
||||
|
||||
|
@ -358,28 +342,36 @@ method subscribeTopic*(g: GossipSub,
|
|||
topic: string,
|
||||
subscribe: bool,
|
||||
peerId: string) {.gcsafe, async.} =
|
||||
await procCall PubSub(g).subscribeTopic(topic, subscribe, peerId)
|
||||
|
||||
if topic notin g.gossipsub:
|
||||
g.gossipsub[topic] = initHashSet[string]()
|
||||
await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId)
|
||||
|
||||
let peer = g.peers.getOrDefault(peerId)
|
||||
if peer == nil:
|
||||
debug "subscribeTopic on a nil peer!"
|
||||
return
|
||||
|
||||
if subscribe:
|
||||
trace "adding subscription for topic", peer = peerId, name = topic
|
||||
# subscribe remote peer to the topic
|
||||
g.gossipsub[topic].incl(peerId)
|
||||
discard g.gossipsub.addPeer(topic, peer)
|
||||
if peerId in g.explicitPeers:
|
||||
g.explicit[topic].incl(peerId)
|
||||
discard g.explicit.addPeer(topic, peer)
|
||||
else:
|
||||
trace "removing subscription for topic", peer = peerId, name = topic
|
||||
# unsubscribe remote peer from the topic
|
||||
g.gossipsub[topic].excl(peerId)
|
||||
g.gossipsub.removePeer(topic, peer)
|
||||
g.mesh.removePeer(topic, peer)
|
||||
g.fanout.removePeer(topic, peer)
|
||||
if peerId in g.explicitPeers:
|
||||
g.explicit[topic].excl(peerId)
|
||||
g.explicit.removePeer(topic, peer)
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh.peers(topic).int64, labelValues = [topic])
|
||||
libp2p_gossipsub_peers_per_topic_fanout
|
||||
.set(g.fanout.peers(topic).int64, labelValues = [topic])
|
||||
libp2p_gossipsub_peers_per_topic_gossipsub
|
||||
.set(g.gossipsub[topic].len.int64, labelValues = [topic])
|
||||
.set(g.gossipsub.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
trace "gossip peers", peers = g.gossipsub[topic].len, topic
|
||||
trace "gossip peers", peers = g.gossipsub.peers(topic), topic
|
||||
|
||||
# also rebalance current topic if we are subbed to
|
||||
if topic in g.topics:
|
||||
|
@ -387,43 +379,46 @@ method subscribeTopic*(g: GossipSub,
|
|||
|
||||
proc handleGraft(g: GossipSub,
|
||||
peer: PubSubPeer,
|
||||
grafts: seq[ControlGraft],
|
||||
respControl: var ControlMessage) =
|
||||
grafts: seq[ControlGraft]): seq[ControlPrune] =
|
||||
for graft in grafts:
|
||||
trace "processing graft message", peer = peer.id,
|
||||
topicID = graft.topicID
|
||||
let topic = graft.topicID
|
||||
trace "processing graft message", topic, peer = peer.id
|
||||
|
||||
# It is an error to GRAFT on a explicit peer
|
||||
if peer.peerInfo.maintain:
|
||||
trace "attempt to graft an explicit peer", peer=peer.id,
|
||||
topicID=graft.topicID
|
||||
# and such an attempt should be logged and rejected with a PRUNE
|
||||
respControl.prune.add(ControlPrune(topicID: graft.topicID))
|
||||
result.add(ControlPrune(topicID: graft.topicID))
|
||||
continue
|
||||
|
||||
if graft.topicID in g.topics:
|
||||
if g.mesh.len < GossipSubD:
|
||||
g.mesh[graft.topicID].incl(peer.id)
|
||||
# If they send us a graft before they send us a subscribe, what should
|
||||
# we do? For now, we add them to mesh but don't add them to gossipsub.
|
||||
if topic in g.topics:
|
||||
if g.mesh.peers(topic) < GossipSubDHi:
|
||||
# In the spec, there's no mention of DHi here, but implicitly, a
|
||||
# peer will be removed from the mesh on next rebalance, so we don't want
|
||||
# this peer to push someone else out
|
||||
if g.mesh.addPeer(topic, peer):
|
||||
g.fanout.removePeer(topic, peer)
|
||||
else:
|
||||
trace "Peer already in mesh", topic, peer = peer.id
|
||||
else:
|
||||
g.gossipsub[graft.topicID].incl(peer.id)
|
||||
result.add(ControlPrune(topicID: topic))
|
||||
else:
|
||||
respControl.prune.add(ControlPrune(topicID: graft.topicID))
|
||||
result.add(ControlPrune(topicID: topic))
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh[graft.topicID].len.int64, labelValues = [graft.topicID])
|
||||
|
||||
libp2p_gossipsub_peers_per_topic_gossipsub
|
||||
.set(g.gossipsub[graft.topicID].len.int64, labelValues = [graft.topicID])
|
||||
.set(g.mesh.peers(topic).int64, labelValues = [topic])
|
||||
|
||||
proc handlePrune(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) =
|
||||
for prune in prunes:
|
||||
trace "processing prune message", peer = peer.id,
|
||||
topicID = prune.topicID
|
||||
|
||||
if prune.topicID in g.mesh:
|
||||
g.mesh[prune.topicID].excl(peer.id)
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh[prune.topicID].len.int64, labelValues = [prune.topicID])
|
||||
g.mesh.removePeer(prune.topicID, peer)
|
||||
libp2p_gossipsub_peers_per_topic_mesh
|
||||
.set(g.mesh.peers(prune.topicID).int64, labelValues = [prune.topicID])
|
||||
|
||||
proc handleIHave(g: GossipSub,
|
||||
peer: PubSubPeer,
|
||||
|
@ -456,7 +451,7 @@ method rpcHandler*(g: GossipSub,
|
|||
|
||||
for m in rpcMsgs: # for all RPC messages
|
||||
if m.messages.len > 0: # if there are any messages
|
||||
var toSendPeers: HashSet[string]
|
||||
var toSendPeers: HashSet[PubSubPeer]
|
||||
for msg in m.messages: # for every message
|
||||
let msgId = g.msgIdProvider(msg)
|
||||
logScope: msgId
|
||||
|
@ -506,25 +501,24 @@ method rpcHandler*(g: GossipSub,
|
|||
let (published, failed) = await g.sendHelper(toSendPeers, m.messages)
|
||||
for p in failed:
|
||||
let peer = g.peers.getOrDefault(p)
|
||||
if not(isNil(peer)):
|
||||
if not isNil(peer):
|
||||
g.handleDisconnect(peer) # cleanup failed peers
|
||||
|
||||
trace "forwared message to peers", peers = published.len
|
||||
|
||||
var respControl: ControlMessage
|
||||
if m.control.isSome:
|
||||
var control: ControlMessage = m.control.get()
|
||||
let iWant: ControlIWant = g.handleIHave(peer, control.ihave)
|
||||
if iWant.messageIDs.len > 0:
|
||||
respControl.iwant.add(iWant)
|
||||
let messages: seq[Message] = g.handleIWant(peer, control.iwant)
|
||||
|
||||
g.handleGraft(peer, control.graft, respControl)
|
||||
let control = m.control.get()
|
||||
g.handlePrune(peer, control.prune)
|
||||
|
||||
respControl.iwant.add(g.handleIHave(peer, control.ihave))
|
||||
respControl.prune.add(g.handleGraft(peer, control.graft))
|
||||
|
||||
if respControl.graft.len > 0 or respControl.prune.len > 0 or
|
||||
respControl.ihave.len > 0 or respControl.iwant.len > 0:
|
||||
await peer.send(@[RPCMsg(control: some(respControl), messages: messages)])
|
||||
await peer.send(
|
||||
@[RPCMsg(control: some(respControl),
|
||||
messages: g.handleIWant(peer, control.iwant))])
|
||||
|
||||
method subscribe*(g: GossipSub,
|
||||
topic: string,
|
||||
|
@ -541,9 +535,9 @@ method unsubscribe*(g: GossipSub,
|
|||
if topic in g.mesh:
|
||||
let peers = g.mesh.getOrDefault(topic)
|
||||
g.mesh.del(topic)
|
||||
for id in peers:
|
||||
let p = g.peers[id]
|
||||
await p.sendPrune(@[topic])
|
||||
|
||||
for peer in peers:
|
||||
await peer.sendPrune(@[topic])
|
||||
|
||||
method publish*(g: GossipSub,
|
||||
topic: string,
|
||||
|
@ -554,40 +548,51 @@ method publish*(g: GossipSub,
|
|||
data = data.shortLog
|
||||
# directly copy explicit peers
|
||||
# as we will always publish to those
|
||||
var peers = g.explicitPeers
|
||||
var peers = initHashSet[PubSubPeer]()
|
||||
if topic.len <= 0: # data could be 0/empty
|
||||
return 0
|
||||
|
||||
if topic.len > 0: # data could be 0/empty
|
||||
if g.parameters.floodPublish:
|
||||
for id, peer in g.peers:
|
||||
if topic in peer.topics and
|
||||
peer.score() >= g.parameters.publishThreshold:
|
||||
debug "publish: including flood/high score peer", peer = id
|
||||
peers.incl(id)
|
||||
if g.parameters.floodPublish:
|
||||
for id, peer in g.peers:
|
||||
if topic in peer.topics and
|
||||
peer.score() >= g.parameters.publishThreshold:
|
||||
debug "publish: including flood/high score peer", peer = id
|
||||
peers.incl(peer)
|
||||
|
||||
# add always direct peers
|
||||
peers.incl(g.explicit.getOrDefault(topic))
|
||||
|
||||
if topic in g.topics: # if we're subscribed use the mesh
|
||||
peers = g.mesh.getOrDefault(topic)
|
||||
peers.incl(g.mesh.getOrDefault(topic))
|
||||
else: # not subscribed, send to fanout peers
|
||||
# try optimistically
|
||||
peers = g.fanout.getOrDefault(topic)
|
||||
peers.incl(g.fanout.getOrDefault(topic))
|
||||
if peers.len == 0:
|
||||
# ok we had nothing.. let's try replenish inline
|
||||
g.replenishFanout(topic)
|
||||
peers = g.fanout.getOrDefault(topic)
|
||||
peers.incl(g.fanout.getOrDefault(topic))
|
||||
|
||||
# even if we couldn't publish,
|
||||
# we still attempted to publish
|
||||
# on the topic, so it makes sense
|
||||
# to update the last topic publish
|
||||
# time
|
||||
g.lastFanoutPubSub[topic] = Moment.fromNow(GossipSubFanoutTTL)
|
||||
|
||||
let
|
||||
msg = Message.init(g.peerInfo, data, topic, g.sign)
|
||||
msgId = g.msgIdProvider(msg)
|
||||
|
||||
trace "created new message", msg
|
||||
|
||||
trace "publishing on topic", name = topic, peers = peers
|
||||
trace "publishing on topic",
|
||||
topic, peers = peers.len, msg = msg.shortLog()
|
||||
if msgId notin g.mcache:
|
||||
g.mcache.put(msgId, msg)
|
||||
|
||||
let (published, failed) = await g.sendHelper(peers, @[msg])
|
||||
for p in failed:
|
||||
let peer = g.peers.getOrDefault(p)
|
||||
g.handleDisconnect(peer) # cleanup failed peers
|
||||
if not isNil(peer):
|
||||
g.handleDisconnect(peer) # cleanup failed peers
|
||||
|
||||
if published.len > 0:
|
||||
libp2p_pubsub_messages_published.inc(labelValues = [topic])
|
||||
|
@ -625,9 +630,9 @@ method initPubSub*(g: GossipSub) =
|
|||
|
||||
randomize()
|
||||
g.mcache = newMCache(GossipSubHistoryGossip, GossipSubHistoryLength)
|
||||
g.mesh = initTable[string, HashSet[string]]() # meshes - topic to peer
|
||||
g.fanout = initTable[string, HashSet[string]]() # fanout - topic to peer
|
||||
g.gossipsub = initTable[string, HashSet[string]]()# topic to peer map of all gossipsub peers
|
||||
g.mesh = initTable[string, HashSet[PubSubPeer]]() # meshes - topic to peer
|
||||
g.fanout = initTable[string, HashSet[PubSubPeer]]() # fanout - topic to peer
|
||||
g.gossipsub = initTable[string, HashSet[PubSubPeer]]()# topic to peer map of all gossipsub peers
|
||||
g.lastFanoutPubSub = initTable[string, Moment]() # last publish time for fanout topics
|
||||
g.gossip = initTable[string, seq[ControlIHave]]() # pending gossip
|
||||
g.control = initTable[string, ControlMessage]() # pending control messages
|
||||
|
|
|
@ -31,6 +31,8 @@ declareCounter(libp2p_pubsub_validation_failure, "pubsub failed validated messag
|
|||
declarePublicCounter(libp2p_pubsub_messages_published, "published messages", labels = ["topic"])
|
||||
|
||||
type
|
||||
PeerTable* = Table[string, HashSet[PubSubPeer]]
|
||||
|
||||
SendRes = tuple[published: seq[string], failed: seq[string]] # keep private
|
||||
|
||||
TopicHandler* = proc(topic: string,
|
||||
|
@ -58,19 +60,28 @@ type
|
|||
cleanupLock: AsyncLock
|
||||
validators*: Table[string, HashSet[ValidatorHandler]]
|
||||
observers: ref seq[PubSubObserver] # ref as in smart_ptr
|
||||
msgIdProvider*: MsgIdProvider # Turn message into message id (not nil)
|
||||
msgIdProvider*: MsgIdProvider # Turn message into message id (not nil)
|
||||
|
||||
proc hasPeerID*(t: PeerTable, topic, peerId: string): bool =
|
||||
# unefficient but used only in tests!
|
||||
let peers = t.getOrDefault(topic)
|
||||
if peers.len == 0:
|
||||
false
|
||||
else:
|
||||
let ps = toSeq(peers)
|
||||
ps.any do (peer: PubSubPeer) -> bool:
|
||||
peer.id == peerId
|
||||
|
||||
method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} =
|
||||
## handle peer disconnects
|
||||
##
|
||||
if peer.id in p.peers:
|
||||
trace "deleting peer", peer = peer.id, stack = getStackTrace()
|
||||
p.peers[peer.id] = nil
|
||||
if not isNil(peer.peerInfo) and peer.id in p.peers:
|
||||
trace "deleting peer", peer = peer.id
|
||||
p.peers.del(peer.id)
|
||||
trace "peer disconnected", peer = peer.id
|
||||
|
||||
# metrics
|
||||
libp2p_pubsub_peers.set(p.peers.len.int64)
|
||||
trace "peer disconnected", peer = peer.id
|
||||
# metrics
|
||||
libp2p_pubsub_peers.set(p.peers.len.int64)
|
||||
|
||||
proc sendSubs*(p: PubSub,
|
||||
peer: PubSubPeer,
|
||||
|
@ -127,19 +138,22 @@ method rpcHandler*(p: PubSub,
|
|||
trace "about to subscribe to topic", topicId = s.topic
|
||||
await p.subscribeTopic(s.topic, s.subscribe, peer.id)
|
||||
|
||||
proc getPeer(p: PubSub,
|
||||
peerInfo: PeerInfo,
|
||||
proto: string): PubSubPeer =
|
||||
proc getOrCreatePeer(p: PubSub,
|
||||
peerInfo: PeerInfo,
|
||||
proto: string): PubSubPeer =
|
||||
if peerInfo.id in p.peers:
|
||||
return p.peers[peerInfo.id]
|
||||
|
||||
# create new pubsub peer
|
||||
let peer = newPubSubPeer(peerInfo, proto)
|
||||
trace "created new pubsub peer", peerId = peer.id, stack = getStackTrace()
|
||||
trace "created new pubsub peer", peerId = peer.id
|
||||
|
||||
p.peers[peer.id] = peer
|
||||
peer.observers = p.observers
|
||||
|
||||
# metrics
|
||||
libp2p_pubsub_peers.set(p.peers.len.int64)
|
||||
|
||||
return peer
|
||||
|
||||
method handleConn*(p: PubSub,
|
||||
|
@ -165,7 +179,7 @@ method handleConn*(p: PubSub,
|
|||
# call pubsub rpc handler
|
||||
await p.rpcHandler(peer, msgs)
|
||||
|
||||
let peer = p.getPeer(conn.peerInfo, proto)
|
||||
let peer = p.getOrCreatePeer(conn.peerInfo, proto)
|
||||
let topics = toSeq(p.topics.keys)
|
||||
if topics.len > 0:
|
||||
await p.sendSubs(peer, topics, true)
|
||||
|
@ -184,23 +198,27 @@ method handleConn*(p: PubSub,
|
|||
|
||||
method subscribePeer*(p: PubSub, conn: Connection) {.base.} =
|
||||
if not(isNil(conn)):
|
||||
let peer = p.getPeer(conn.peerInfo, p.codec)
|
||||
let peer = p.getOrCreatePeer(conn.peerInfo, p.codec)
|
||||
trace "subscribing to peer", peerId = conn.peerInfo.id
|
||||
if not peer.connected:
|
||||
peer.conn = conn
|
||||
|
||||
method unsubscribePeer*(p: PubSub, peerInfo: PeerInfo) {.base, async.} =
|
||||
let peer = p.getPeer(peerInfo, p.codec)
|
||||
trace "unsubscribing from peer", peerId = $peerInfo
|
||||
if not(isNil(peer.conn)):
|
||||
await peer.conn.close()
|
||||
if peerInfo.id in p.peers:
|
||||
let peer = p.peers[peerInfo.id]
|
||||
|
||||
p.handleDisconnect(peer)
|
||||
trace "unsubscribing from peer", peerId = $peerInfo
|
||||
if not(isNil(peer.conn)):
|
||||
await peer.conn.close()
|
||||
|
||||
proc connected*(p: PubSub, peer: PeerInfo): bool =
|
||||
let peer = p.getPeer(peer, p.codec)
|
||||
if not(isNil(peer)):
|
||||
return peer.connected
|
||||
p.handleDisconnect(peer)
|
||||
|
||||
proc connected*(p: PubSub, peerInfo: PeerInfo): bool =
|
||||
if peerInfo.id in p.peers:
|
||||
let peer = p.peers[peerInfo.id]
|
||||
|
||||
if not(isNil(peer)):
|
||||
return peer.connected
|
||||
|
||||
method unsubscribe*(p: PubSub,
|
||||
topics: seq[TopicPair]) {.base, async.} =
|
||||
|
@ -212,6 +230,11 @@ method unsubscribe*(p: PubSub,
|
|||
if h == t.handler:
|
||||
p.topics[t.topic].handler.del(i)
|
||||
|
||||
# make sure we delete the topic if
|
||||
# no more handlers are left
|
||||
if p.topics[t.topic].handler.len <= 0:
|
||||
p.topics.del(t.topic)
|
||||
|
||||
method unsubscribe*(p: PubSub,
|
||||
topic: string,
|
||||
handler: TopicHandler): Future[void] {.base.} =
|
||||
|
@ -242,20 +265,16 @@ method subscribe*(p: PubSub,
|
|||
libp2p_pubsub_topics.inc()
|
||||
|
||||
proc sendHelper*(p: PubSub,
|
||||
sendPeers: HashSet[string],
|
||||
sendPeers: HashSet[PubSubPeer],
|
||||
msgs: seq[Message]): Future[SendRes] {.async.} =
|
||||
var sent: seq[tuple[id: string, fut: Future[void]]]
|
||||
for sendPeer in sendPeers:
|
||||
# avoid sending to self
|
||||
if sendPeer == p.peerInfo.id:
|
||||
if sendPeer.peerInfo == p.peerInfo:
|
||||
continue
|
||||
|
||||
let peer = p.peers.getOrDefault(sendPeer)
|
||||
if isNil(peer):
|
||||
continue
|
||||
|
||||
trace "sending messages to peer", peer = peer.id, msgs
|
||||
sent.add((id: peer.id, fut: peer.send(@[RPCMsg(messages: msgs)])))
|
||||
trace "sending messages to peer", peer = sendPeer.id, msgs
|
||||
sent.add((id: sendPeer.id, fut: sendPeer.send(@[RPCMsg(messages: msgs)])))
|
||||
|
||||
var published: seq[string]
|
||||
var failed: seq[string]
|
||||
|
|
|
@ -57,6 +57,35 @@ func score*(p: PubSubPeer): float64 =
|
|||
# TODO
|
||||
0.0
|
||||
|
||||
func hash*(p: PubSubPeer): Hash =
|
||||
# int is either 32/64, so intptr basically, pubsubpeer is a ref
|
||||
cast[pointer](p).hash
|
||||
|
||||
func `==`*(a, b: PubSubPeer): bool =
|
||||
# override equiality to support both nil and peerInfo comparisons
|
||||
# this in the future will allow us to recycle refs
|
||||
let
|
||||
aptr = cast[pointer](a)
|
||||
bptr = cast[pointer](b)
|
||||
if aptr == nil:
|
||||
if bptr == nil:
|
||||
true
|
||||
else:
|
||||
false
|
||||
elif bptr == nil:
|
||||
false
|
||||
else:
|
||||
if a.peerInfo == nil:
|
||||
if b.peerInfo == nil:
|
||||
true
|
||||
else:
|
||||
false
|
||||
else:
|
||||
if b.peerInfo == nil:
|
||||
false
|
||||
else:
|
||||
a.peerInfo.id == b.peerInfo.id
|
||||
|
||||
proc id*(p: PubSubPeer): string = p.peerInfo.id
|
||||
|
||||
proc connected*(p: PubSubPeer): bool =
|
||||
|
@ -176,14 +205,24 @@ proc sendMsg*(p: PubSubPeer,
|
|||
p.send(@[RPCMsg(messages: @[Message.init(p.peerInfo, data, topic, sign)])])
|
||||
|
||||
proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} =
|
||||
for topic in topics:
|
||||
trace "sending graft msg to peer", peer = p.id, topicID = topic
|
||||
await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))])
|
||||
try:
|
||||
for topic in topics:
|
||||
trace "sending graft msg to peer", peer = p.id, topicID = topic
|
||||
await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))])
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "Could not send graft", msg = exc.msg
|
||||
|
||||
proc sendPrune*(p: PubSubPeer, topics: seq[string], peers: seq[PeerInfoMsg] = @[], backoff: uint64 = 0) {.async.} =
|
||||
for topic in topics:
|
||||
trace "sending prune msg to peer", peer = p.id, topicID = topic
|
||||
await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic, peers: peers, backoff: backoff)])))])
|
||||
proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} =
|
||||
try:
|
||||
for topic in topics:
|
||||
trace "sending prune msg to peer", peer = p.id, topicID = topic
|
||||
await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))])
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
trace "Could not send prune", msg = exc.msg
|
||||
|
||||
proc `$`*(p: PubSubPeer): string =
|
||||
p.id
|
||||
|
|
|
@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string =
|
|||
byteutils.toHex(m.seqno) & m.fromPeer.pretty
|
||||
|
||||
proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} =
|
||||
var buff = initProtoBuffer()
|
||||
encodeMessage(msg, buff)
|
||||
p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes()
|
||||
p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes()
|
||||
|
||||
proc verify*(m: Message, p: PeerInfo): bool =
|
||||
if m.signature.len > 0 and m.key.len > 0:
|
||||
|
@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool =
|
|||
msg.signature = @[]
|
||||
msg.key = @[]
|
||||
|
||||
var buff = initProtoBuffer()
|
||||
encodeMessage(msg, buff)
|
||||
|
||||
var remote: Signature
|
||||
var key: PublicKey
|
||||
if remote.init(m.signature) and key.init(m.key):
|
||||
trace "verifying signature", remoteSignature = remote
|
||||
result = remote.verify(PubSubPrefix & buff.buffer, key)
|
||||
result = remote.verify(PubSubPrefix & encodeMessage(msg), key)
|
||||
|
||||
if result:
|
||||
libp2p_pubsub_sig_verify_success.inc()
|
||||
|
|
|
@ -14,269 +14,247 @@ import messages,
|
|||
../../../utility,
|
||||
../../../protobuf/minprotobuf
|
||||
|
||||
proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, graft.topicID))
|
||||
proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, graft.topicID)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} =
|
||||
trace "decoding graft msg", buffer = pb.buffer.shortLog
|
||||
while true:
|
||||
var topic: string
|
||||
if pb.getString(1, topic) < 0:
|
||||
break
|
||||
proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, prune.topicID)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
trace "read topic field from graft msg", topicID = topic
|
||||
result.add(ControlGraft(topicID: topic))
|
||||
|
||||
proc encodePeerInfo*(info: PeerInfoMsg, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, info.peerID))
|
||||
pb.write(initProtoField(2, info.signedPeerRecord))
|
||||
|
||||
proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, prune.topicID))
|
||||
|
||||
var peers = initProtoBuffer()
|
||||
for p in prune.peers:
|
||||
p.encodePeerInfo(peers)
|
||||
peers.finish()
|
||||
pb.write(initProtoField(2, peers))
|
||||
|
||||
pb.write(initProtoField(3, prune.backoff))
|
||||
|
||||
proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
|
||||
trace "decoding prune msg"
|
||||
while true:
|
||||
var topic: string
|
||||
if pb.getString(1, topic) < 0:
|
||||
break
|
||||
|
||||
trace "read topic field from prune msg", topicID = topic
|
||||
result.add(ControlPrune(topicID: topic))
|
||||
|
||||
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, ihave.topicID))
|
||||
proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, ihave.topicID)
|
||||
for mid in ihave.messageIDs:
|
||||
pb.write(initProtoField(2, mid))
|
||||
ipb.write(2, mid)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} =
|
||||
trace "decoding ihave msg"
|
||||
|
||||
while true:
|
||||
var control: ControlIHave
|
||||
if pb.getString(1, control.topicID) < 0:
|
||||
trace "topic field missing from ihave msg"
|
||||
break
|
||||
|
||||
trace "read topic field", topicID = control.topicID
|
||||
|
||||
while true:
|
||||
var mid: string
|
||||
if pb.getString(2, mid) < 0:
|
||||
break
|
||||
trace "read messageID field", mid = mid
|
||||
control.messageIDs.add(mid)
|
||||
|
||||
result.add(control)
|
||||
|
||||
proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
|
||||
var ipb = initProtoBuffer()
|
||||
for mid in iwant.messageIDs:
|
||||
pb.write(initProtoField(1, mid))
|
||||
ipb.write(1, mid)
|
||||
if len(ipb.buffer) > 0:
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} =
|
||||
trace "decoding iwant msg"
|
||||
proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
|
||||
var ipb = initProtoBuffer()
|
||||
for ihave in control.ihave:
|
||||
ipb.write(1, ihave)
|
||||
for iwant in control.iwant:
|
||||
ipb.write(2, iwant)
|
||||
for graft in control.graft:
|
||||
ipb.write(3, graft)
|
||||
for prune in control.prune:
|
||||
ipb.write(4, prune)
|
||||
if len(ipb.buffer) > 0:
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
var control: ControlIWant
|
||||
while true:
|
||||
var mid: string
|
||||
if pb.getString(1, mid) < 0:
|
||||
break
|
||||
control.messageIDs.add(mid)
|
||||
trace "read messageID field", mid = mid
|
||||
result.add(control)
|
||||
|
||||
proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
if control.ihave.len > 0:
|
||||
var ihave = initProtoBuffer()
|
||||
for h in control.ihave:
|
||||
h.encodeIHave(ihave)
|
||||
|
||||
# write messages to protobuf
|
||||
ihave.finish()
|
||||
pb.write(initProtoField(1, ihave))
|
||||
|
||||
if control.iwant.len > 0:
|
||||
var iwant = initProtoBuffer()
|
||||
for w in control.iwant:
|
||||
w.encodeIWant(iwant)
|
||||
|
||||
# write messages to protobuf
|
||||
iwant.finish()
|
||||
pb.write(initProtoField(2, iwant))
|
||||
|
||||
if control.graft.len > 0:
|
||||
var graft = initProtoBuffer()
|
||||
for g in control.graft:
|
||||
g.encodeGraft(graft)
|
||||
|
||||
# write messages to protobuf
|
||||
graft.finish()
|
||||
pb.write(initProtoField(3, graft))
|
||||
|
||||
if control.prune.len > 0:
|
||||
var prune = initProtoBuffer()
|
||||
for p in control.prune:
|
||||
p.encodePrune(prune)
|
||||
|
||||
# write messages to protobuf
|
||||
prune.finish()
|
||||
pb.write(initProtoField(4, prune))
|
||||
|
||||
proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} =
|
||||
trace "decoding control submessage"
|
||||
var control: ControlMessage
|
||||
while true:
|
||||
var field = pb.enterSubMessage()
|
||||
trace "processing submessage", field = field
|
||||
case field:
|
||||
of 0:
|
||||
trace "no submessage found in Control msg"
|
||||
break
|
||||
of 1:
|
||||
control.ihave &= pb.decodeIHave()
|
||||
of 2:
|
||||
control.iwant &= pb.decodeIWant()
|
||||
of 3:
|
||||
control.graft &= pb.decodeGraft()
|
||||
of 4:
|
||||
control.prune &= pb.decodePrune()
|
||||
else:
|
||||
raise newException(CatchableError, "message type not recognized")
|
||||
|
||||
if result.isNone:
|
||||
result = some(control)
|
||||
|
||||
proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, subs.subscribe))
|
||||
pb.write(initProtoField(2, subs.topic))
|
||||
|
||||
proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} =
|
||||
while true:
|
||||
var subOpt: SubOpts
|
||||
var subscr: uint
|
||||
discard pb.getVarintValue(1, subscr)
|
||||
subOpt.subscribe = cast[bool](subscr)
|
||||
trace "read subscribe field", subscribe = subOpt.subscribe
|
||||
|
||||
if pb.getString(2, subOpt.topic) < 0:
|
||||
break
|
||||
trace "read subscribe field", topicName = subOpt.topic
|
||||
|
||||
result.add(subOpt)
|
||||
|
||||
trace "got subscriptions", subscriptions = result
|
||||
|
||||
proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} =
|
||||
pb.write(initProtoField(1, msg.fromPeer.getBytes()))
|
||||
pb.write(initProtoField(2, msg.data))
|
||||
pb.write(initProtoField(3, msg.seqno))
|
||||
|
||||
for t in msg.topicIDs:
|
||||
pb.write(initProtoField(4, t))
|
||||
|
||||
if msg.signature.len > 0:
|
||||
pb.write(initProtoField(5, msg.signature))
|
||||
|
||||
if msg.key.len > 0:
|
||||
pb.write(initProtoField(6, msg.key))
|
||||
proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) =
|
||||
var ipb = initProtoBuffer()
|
||||
ipb.write(1, uint64(subs.subscribe))
|
||||
ipb.write(2, subs.topic)
|
||||
ipb.finish()
|
||||
pb.write(field, ipb)
|
||||
|
||||
proc encodeMessage*(msg: Message): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, msg.fromPeer)
|
||||
pb.write(2, msg.data)
|
||||
pb.write(3, msg.seqno)
|
||||
for topic in msg.topicIDs:
|
||||
pb.write(4, topic)
|
||||
if len(msg.signature) > 0:
|
||||
pb.write(5, msg.signature)
|
||||
if len(msg.key) > 0:
|
||||
pb.write(6, msg.key)
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} =
|
||||
# TODO: which of this fields are really optional?
|
||||
while true:
|
||||
var msg: Message
|
||||
var fromPeer: seq[byte]
|
||||
if pb.getBytes(1, fromPeer) < 0:
|
||||
break
|
||||
try:
|
||||
msg.fromPeer = PeerID.init(fromPeer).tryGet()
|
||||
except CatchableError as err:
|
||||
debug "Invalid fromPeer in message", msg = err.msg
|
||||
break
|
||||
proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
|
||||
pb.write(field, encodeMessage(msg))
|
||||
|
||||
trace "read message field", fromPeer = msg.fromPeer.pretty
|
||||
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
|
||||
trace "decodeGraft: decoding message"
|
||||
var control = ControlGraft()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeGraft: read topicId", topic_id = topicId
|
||||
else:
|
||||
trace "decodeGraft: topicId is missing"
|
||||
control
|
||||
|
||||
if pb.getBytes(2, msg.data) < 0:
|
||||
break
|
||||
trace "read message field", data = msg.data.shortLog
|
||||
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
|
||||
trace "decodePrune: decoding message"
|
||||
var control = ControlPrune()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodePrune: read topicId", topic_id = topicId
|
||||
else:
|
||||
trace "decodePrune: topicId is missing"
|
||||
control
|
||||
|
||||
if pb.getBytes(3, msg.seqno) < 0:
|
||||
break
|
||||
trace "read message field", seqno = msg.seqno.shortLog
|
||||
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
|
||||
trace "decodeIHave: decoding message"
|
||||
var control = ControlIHave()
|
||||
var topicId: string
|
||||
if pb.getField(1, topicId):
|
||||
control.topicId = topicId
|
||||
trace "decodeIHave: read topicId", topic_id = topicId
|
||||
else:
|
||||
trace "decodeIHave: topicId is missing"
|
||||
if pb.getRepeatedField(2, control.messageIDs):
|
||||
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIHave: no messageIDs"
|
||||
control
|
||||
|
||||
var topic: string
|
||||
while true:
|
||||
if pb.getString(4, topic) < 0:
|
||||
break
|
||||
msg.topicIDs.add(topic)
|
||||
trace "read message field", topicName = topic
|
||||
topic = ""
|
||||
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
|
||||
trace "decodeIWant: decoding message"
|
||||
var control = ControlIWant()
|
||||
if pb.getRepeatedField(1, control.messageIDs):
|
||||
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
|
||||
else:
|
||||
trace "decodeIWant: no messageIDs"
|
||||
|
||||
discard pb.getBytes(5, msg.signature)
|
||||
trace "read message field", signature = msg.signature.shortLog
|
||||
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
|
||||
trace "decodeControl: decoding message"
|
||||
var buffer: seq[byte]
|
||||
if pb.getField(3, buffer):
|
||||
var control: ControlMessage
|
||||
var cpb = initProtoBuffer(buffer)
|
||||
var ihavepbs: seq[seq[byte]]
|
||||
var iwantpbs: seq[seq[byte]]
|
||||
var graftpbs: seq[seq[byte]]
|
||||
var prunepbs: seq[seq[byte]]
|
||||
|
||||
discard pb.getBytes(6, msg.key)
|
||||
trace "read message field", key = msg.key.shortLog
|
||||
discard cpb.getRepeatedField(1, ihavepbs)
|
||||
discard cpb.getRepeatedField(2, iwantpbs)
|
||||
discard cpb.getRepeatedField(3, graftpbs)
|
||||
discard cpb.getRepeatedField(4, prunepbs)
|
||||
|
||||
result.add(msg)
|
||||
for item in ihavepbs:
|
||||
control.ihave.add(decodeIHave(initProtoBuffer(item)))
|
||||
for item in iwantpbs:
|
||||
control.iwant.add(decodeIWant(initProtoBuffer(item)))
|
||||
for item in graftpbs:
|
||||
control.graft.add(decodeGraft(initProtoBuffer(item)))
|
||||
for item in prunepbs:
|
||||
control.prune.add(decodePrune(initProtoBuffer(item)))
|
||||
|
||||
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} =
|
||||
result = initProtoBuffer()
|
||||
trace "encoding msg: ", msg = msg.shortLog
|
||||
trace "decodeControl: "
|
||||
some(control)
|
||||
else:
|
||||
none[ControlMessage]()
|
||||
|
||||
if msg.subscriptions.len > 0:
|
||||
for s in msg.subscriptions:
|
||||
var subs = initProtoBuffer()
|
||||
encodeSubs(s, subs)
|
||||
# write subscriptions to protobuf
|
||||
subs.finish()
|
||||
result.write(initProtoField(1, subs))
|
||||
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
|
||||
trace "decodeSubscription: decoding message"
|
||||
var subflag: uint64
|
||||
var sub = SubOpts()
|
||||
if pb.getField(1, subflag):
|
||||
sub.subscribe = bool(subflag)
|
||||
trace "decodeSubscription: read subscribe", subscribe = subflag
|
||||
else:
|
||||
trace "decodeSubscription: subscribe is missing"
|
||||
if pb.getField(2, sub.topic):
|
||||
trace "decodeSubscription: read topic", topic = sub.topic
|
||||
else:
|
||||
trace "decodeSubscription: topic is missing"
|
||||
|
||||
if msg.messages.len > 0:
|
||||
var messages = initProtoBuffer()
|
||||
for m in msg.messages:
|
||||
encodeMessage(m, messages)
|
||||
sub
|
||||
|
||||
# write messages to protobuf
|
||||
messages.finish()
|
||||
result.write(initProtoField(2, messages))
|
||||
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
|
||||
trace "decodeSubscriptions: decoding message"
|
||||
var subpbs: seq[seq[byte]]
|
||||
var subs: seq[SubOpts]
|
||||
if pb.getRepeatedField(1, subpbs):
|
||||
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
|
||||
for item in subpbs:
|
||||
let sub = decodeSubscription(initProtoBuffer(item))
|
||||
subs.add(sub)
|
||||
|
||||
if msg.control.isSome:
|
||||
var control = initProtoBuffer()
|
||||
msg.control.get.encodeControl(control)
|
||||
if len(subs) == 0:
|
||||
trace "decodeSubscription: no subscriptions found"
|
||||
|
||||
# write messages to protobuf
|
||||
control.finish()
|
||||
result.write(initProtoField(3, control))
|
||||
subs
|
||||
|
||||
if result.buffer.len > 0:
|
||||
result.finish()
|
||||
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
|
||||
trace "decodeMessage: decoding message"
|
||||
var msg: Message
|
||||
if pb.getField(1, msg.fromPeer):
|
||||
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
|
||||
else:
|
||||
trace "decodeMessage: fromPeer is missing"
|
||||
|
||||
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} =
|
||||
if pb.getField(2, msg.data):
|
||||
trace "decodeMessage: read data", data = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: data is missing"
|
||||
|
||||
if pb.getField(3, msg.seqno):
|
||||
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: seqno is missing"
|
||||
|
||||
if pb.getRepeatedField(4, msg.topicIDs):
|
||||
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
|
||||
else:
|
||||
trace "decodeMessage: topics are missing"
|
||||
|
||||
if pb.getField(5, msg.signature):
|
||||
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: signature is missing"
|
||||
|
||||
if pb.getField(6, msg.key):
|
||||
trace "decodeMessage: read public key", key = msg.key.shortLog()
|
||||
else:
|
||||
trace "decodeMessage: public key is missing"
|
||||
|
||||
msg
|
||||
|
||||
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
|
||||
trace "decodeMessages: decoding message"
|
||||
var msgpbs: seq[seq[byte]]
|
||||
var msgs: seq[Message]
|
||||
if pb.getRepeatedField(2, msgpbs):
|
||||
trace "decodeMessages: read messages", count = len(msgpbs)
|
||||
for item in msgpbs:
|
||||
let msg = decodeMessage(initProtoBuffer(item))
|
||||
msgs.add(msg)
|
||||
|
||||
if len(msgs) == 0:
|
||||
trace "decodeMessages: no messages found"
|
||||
|
||||
msgs
|
||||
|
||||
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
|
||||
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer()
|
||||
for item in msg.subscriptions:
|
||||
pb.write(1, item)
|
||||
for item in msg.messages:
|
||||
pb.write(2, item)
|
||||
if msg.control.isSome():
|
||||
pb.write(3, msg.control.get())
|
||||
if len(pb.buffer) > 0:
|
||||
pb.finish()
|
||||
result = pb
|
||||
|
||||
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
|
||||
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
|
||||
var pb = initProtoBuffer(msg)
|
||||
var rpcMsg: RPCMsg
|
||||
rpcMsg.messages = pb.decodeMessages()
|
||||
rpcMsg.subscriptions = pb.decodeSubscriptions()
|
||||
rpcMsg.control = pb.decodeControl()
|
||||
|
||||
while true:
|
||||
# decode SubOpts array
|
||||
var field = pb.enterSubMessage()
|
||||
trace "processing submessage", field = field
|
||||
case field:
|
||||
of 0:
|
||||
trace "no submessage found in RPC msg"
|
||||
break
|
||||
of 1:
|
||||
result.subscriptions &= pb.decodeSubs()
|
||||
of 2:
|
||||
result.messages &= pb.decodeMessages()
|
||||
of 3:
|
||||
result.control = pb.decodeControl()
|
||||
else:
|
||||
raise newException(CatchableError, "message type not recognized")
|
||||
rpcMsg
|
||||
|
|
|
@ -12,15 +12,13 @@ import chronicles
|
|||
import bearssl
|
||||
import stew/[endians2, byteutils]
|
||||
import nimcrypto/[utils, sha2, hmac]
|
||||
import ../../stream/lpstream
|
||||
import ../../stream/[connection, streamseq]
|
||||
import ../../peerid
|
||||
import ../../peerinfo
|
||||
import ../../protobuf/minprotobuf
|
||||
import ../../utility
|
||||
import ../../stream/lpstream
|
||||
import secure,
|
||||
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf],
|
||||
../../stream/bufferstream
|
||||
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf]
|
||||
|
||||
logScope:
|
||||
topics = "noise"
|
||||
|
@ -34,7 +32,7 @@ const
|
|||
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
|
||||
|
||||
# Empty is a special value which indicates k has not yet been initialized.
|
||||
EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
EmptyKey = default(ChaChaPolyKey)
|
||||
NonceMax = uint64.high - 1 # max is reserved
|
||||
NoiseSize = 32
|
||||
MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len)
|
||||
|
@ -72,7 +70,7 @@ type
|
|||
Noise* = ref object of Secure
|
||||
rng: ref BrHmacDrbgContext
|
||||
localPrivateKey: PrivateKey
|
||||
localPublicKey: PublicKey
|
||||
localPublicKey: seq[byte]
|
||||
noiseKeys: KeyPair
|
||||
commonPrologue: seq[byte]
|
||||
outgoing: bool
|
||||
|
@ -89,7 +87,7 @@ type
|
|||
# Utility
|
||||
|
||||
proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair =
|
||||
result.privateKey = Curve25519Key.random(rng).tryGet()
|
||||
result.privateKey = Curve25519Key.random(rng)
|
||||
result.publicKey = result.privateKey.public()
|
||||
|
||||
proc hashProtocol(name: string): MDigest[256] =
|
||||
|
@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
|
|||
proc hasKey(cs: CipherState): bool =
|
||||
cs.k != EmptyKey
|
||||
|
||||
proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
|
||||
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||
var
|
||||
tag: ChaChaPolyTag
|
||||
nonce: ChaChaPolyNonce
|
||||
np = cast[ptr uint64](addr nonce[4])
|
||||
np[] = state.n
|
||||
nonce[4..<12] = toBytesLE(state.n)
|
||||
result = @data
|
||||
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
|
||||
inc state.n
|
||||
|
@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte]
|
|||
result &= tag
|
||||
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
|
||||
|
||||
proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
|
||||
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
|
||||
var
|
||||
tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag
|
||||
tagOut = tagIn
|
||||
tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
|
||||
tagOut: ChaChaPolyTag
|
||||
nonce: ChaChaPolyNonce
|
||||
np = cast[ptr uint64](addr nonce[4])
|
||||
np[] = state.n
|
||||
nonce[4..<12] = toBytesLE(state.n)
|
||||
result = data[0..(data.high - ChaChaPolyTag.len)]
|
||||
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
|
||||
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
|
||||
|
@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
|
|||
ss.cs = CipherState(k: temp_keys[1])
|
||||
trace "mixKey", key = ss.cs.k.shortLog
|
||||
|
||||
proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
|
||||
proc mixHash(ss: var SymmetricState; data: openArray[byte]) =
|
||||
var ctx: sha256
|
||||
ctx.init()
|
||||
ctx.update(ss.h.data)
|
||||
|
@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
|
|||
trace "mixHash", hash = ss.h.data.shortLog
|
||||
|
||||
# We might use this for other handshake patterns/tokens
|
||||
proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
|
||||
proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} =
|
||||
var
|
||||
temp_keys: array[3, ChaChaPolyKey]
|
||||
sha256.hkdf(ss.ck, ikm, [], temp_keys)
|
||||
|
@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
|
|||
ss.mixHash(temp_keys[1])
|
||||
ss.cs = CipherState(k: temp_keys[2])
|
||||
|
||||
proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
||||
proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
|
||||
# according to spec if key is empty leave plaintext
|
||||
if ss.cs.hasKey:
|
||||
result = ss.cs.encryptWithAd(ss.h.data, data)
|
||||
|
@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
|||
result = @data
|
||||
ss.mixHash(result)
|
||||
|
||||
proc decryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
|
||||
proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
|
||||
# according to spec if key is empty leave plaintext
|
||||
if ss.cs.hasKey:
|
||||
result = ss.cs.decryptWithAd(ss.h.data, data)
|
||||
|
@ -202,13 +198,13 @@ template write_e: untyped =
|
|||
trace "noise write e"
|
||||
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
|
||||
hs.e = genKeyPair(p.rng[])
|
||||
msg &= hs.e.publicKey
|
||||
msg.add hs.e.publicKey
|
||||
hs.ss.mixHash(hs.e.publicKey)
|
||||
|
||||
template write_s: untyped =
|
||||
trace "noise write s"
|
||||
# Appends EncryptAndHash(s.public_key) to the buffer.
|
||||
msg &= hs.ss.encryptAndHash(hs.s.publicKey)
|
||||
msg.add hs.ss.encryptAndHash(hs.s.publicKey)
|
||||
|
||||
template dh_ee: untyped =
|
||||
trace "noise dh ee"
|
||||
|
@ -244,8 +240,8 @@ template read_e: untyped =
|
|||
raise newException(NoiseHandshakeError, "Noise E, expected more data")
|
||||
|
||||
# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||
hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high]
|
||||
msg = msg[Curve25519Key.len..msg.high]
|
||||
hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high)
|
||||
msg.consume(Curve25519Key.len)
|
||||
hs.ss.mixHash(hs.re)
|
||||
|
||||
template read_s: untyped =
|
||||
|
@ -253,30 +249,33 @@ template read_s: untyped =
|
|||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
|
||||
# Sets rs (which must be empty) to DecryptAndHash(temp).
|
||||
let
|
||||
temp =
|
||||
rsLen =
|
||||
if hs.ss.cs.hasKey:
|
||||
if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
|
||||
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
||||
msg[0..Curve25519Key.high + ChaChaPolyTag.len]
|
||||
Curve25519Key.len + ChaChaPolyTag.len
|
||||
else:
|
||||
if msg.len < Curve25519Key.len:
|
||||
raise newException(NoiseHandshakeError, "Noise S, expected more data")
|
||||
msg[0..Curve25519Key.high]
|
||||
msg = msg[temp.len..msg.high]
|
||||
let plain = hs.ss.decryptAndHash(temp)
|
||||
hs.rs[0..Curve25519Key.high] = plain
|
||||
Curve25519Key.len
|
||||
hs.rs[0..Curve25519Key.high] =
|
||||
hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1))
|
||||
|
||||
msg.consume(rsLen)
|
||||
|
||||
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
|
||||
var besize: array[2, byte]
|
||||
await sconn.readExactly(addr besize[0], besize.len)
|
||||
let size = uint16.fromBytesBE(besize).int
|
||||
trace "receiveHSMessage", size
|
||||
if size == 0:
|
||||
return
|
||||
|
||||
var buffer = newSeq[byte](size)
|
||||
if buffer.len > 0:
|
||||
await sconn.readExactly(addr buffer[0], buffer.len)
|
||||
await sconn.readExactly(addr buffer[0], buffer.len)
|
||||
return buffer
|
||||
|
||||
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
|
||||
proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
|
||||
var
|
||||
lesize = buf.len.uint16
|
||||
besize = lesize.toBytesBE
|
||||
|
@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
|
|||
trace "sendHSMessage", size = lesize
|
||||
outbuf &= besize
|
||||
outbuf &= buf
|
||||
await sconn.write(outbuf)
|
||||
sconn.write(outbuf)
|
||||
|
||||
proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
|
||||
proc handshakeXXOutbound(
|
||||
p: Noise, conn: Connection,
|
||||
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||
const initiator = true
|
||||
|
||||
var
|
||||
hs = HandshakeState.init()
|
||||
p2psecret = p2pProof.buffer
|
||||
|
||||
hs.ss.mixHash(p.commonPrologue)
|
||||
hs.s = p.noiseKeys
|
||||
try:
|
||||
|
||||
# -> e
|
||||
var msg: seq[byte]
|
||||
hs.ss.mixHash(p.commonPrologue)
|
||||
hs.s = p.noiseKeys
|
||||
|
||||
write_e()
|
||||
# -> e
|
||||
var msg: StreamSeq
|
||||
|
||||
# IK might use this btw!
|
||||
msg &= hs.ss.encryptAndHash(@[])
|
||||
write_e()
|
||||
|
||||
await conn.sendHSMessage(msg)
|
||||
# IK might use this btw!
|
||||
msg.add hs.ss.encryptAndHash([])
|
||||
|
||||
# <- e, ee, s, es
|
||||
await conn.sendHSMessage(msg.data)
|
||||
|
||||
msg = await conn.receiveHSMessage()
|
||||
# <- e, ee, s, es
|
||||
|
||||
read_e()
|
||||
dh_ee()
|
||||
read_s()
|
||||
dh_es()
|
||||
msg.assign(await conn.receiveHSMessage())
|
||||
|
||||
let remoteP2psecret = hs.ss.decryptAndHash(msg)
|
||||
read_e()
|
||||
dh_ee()
|
||||
read_s()
|
||||
dh_es()
|
||||
|
||||
# -> s, se
|
||||
let remoteP2psecret = hs.ss.decryptAndHash(msg.data)
|
||||
msg.clear()
|
||||
|
||||
msg.setLen(0)
|
||||
# -> s, se
|
||||
|
||||
write_s()
|
||||
dh_se()
|
||||
write_s()
|
||||
dh_se()
|
||||
|
||||
# last payload must follow the ecrypted way of sending
|
||||
msg &= hs.ss.encryptAndHash(p2psecret)
|
||||
# last payload must follow the encrypted way of sending
|
||||
msg.add hs.ss.encryptAndHash(p2psecret)
|
||||
|
||||
await conn.sendHSMessage(msg)
|
||||
await conn.sendHSMessage(msg.data)
|
||||
|
||||
let (cs1, cs2) = hs.ss.split()
|
||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||
let (cs1, cs2) = hs.ss.split()
|
||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||
finally:
|
||||
burnMem(hs)
|
||||
|
||||
proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
|
||||
proc handshakeXXInbound(
|
||||
p: Noise, conn: Connection,
|
||||
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
|
||||
const initiator = false
|
||||
|
||||
var
|
||||
hs = HandshakeState.init()
|
||||
p2psecret = p2pProof.buffer
|
||||
|
||||
hs.ss.mixHash(p.commonPrologue)
|
||||
hs.s = p.noiseKeys
|
||||
try:
|
||||
hs.ss.mixHash(p.commonPrologue)
|
||||
hs.s = p.noiseKeys
|
||||
|
||||
# -> e
|
||||
# -> e
|
||||
|
||||
var msg = await conn.receiveHSMessage()
|
||||
var msg: StreamSeq
|
||||
msg.add(await conn.receiveHSMessage())
|
||||
|
||||
read_e()
|
||||
read_e()
|
||||
|
||||
# we might use this early data one day, keeping it here for clarity
|
||||
let earlyData {.used.} = hs.ss.decryptAndHash(msg)
|
||||
# we might use this early data one day, keeping it here for clarity
|
||||
let earlyData {.used.} = hs.ss.decryptAndHash(msg.data)
|
||||
|
||||
# <- e, ee, s, es
|
||||
# <- e, ee, s, es
|
||||
|
||||
msg.setLen(0)
|
||||
msg.consume(msg.len)
|
||||
|
||||
write_e()
|
||||
dh_ee()
|
||||
write_s()
|
||||
dh_es()
|
||||
write_e()
|
||||
dh_ee()
|
||||
write_s()
|
||||
dh_es()
|
||||
|
||||
msg &= hs.ss.encryptAndHash(p2psecret)
|
||||
msg.add hs.ss.encryptAndHash(p2psecret)
|
||||
|
||||
await conn.sendHSMessage(msg)
|
||||
await conn.sendHSMessage(msg.data)
|
||||
msg.clear()
|
||||
|
||||
# -> s, se
|
||||
# -> s, se
|
||||
|
||||
msg = await conn.receiveHSMessage()
|
||||
msg.add(await conn.receiveHSMessage())
|
||||
|
||||
read_s()
|
||||
dh_se()
|
||||
read_s()
|
||||
dh_se()
|
||||
|
||||
let remoteP2psecret = hs.ss.decryptAndHash(msg)
|
||||
|
||||
let (cs1, cs2) = hs.ss.split()
|
||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||
let
|
||||
remoteP2psecret = hs.ss.decryptAndHash(msg.data)
|
||||
(cs1, cs2) = hs.ss.split()
|
||||
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
|
||||
finally:
|
||||
burnMem(hs)
|
||||
|
||||
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
|
||||
while true: # Discard 0-length payloads
|
||||
|
@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
|
|||
while left > 0:
|
||||
let
|
||||
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
|
||||
cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
cipher = sconn.writeCs.encryptWithAd(
|
||||
[], message.toOpenArray(offset, offset + chunkSize - 1))
|
||||
left = left - chunkSize
|
||||
offset = offset + chunkSize
|
||||
var
|
||||
|
@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
|
|||
|
||||
var
|
||||
libp2pProof = initProtoBuffer()
|
||||
libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet()))
|
||||
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
|
||||
libp2pProof.write(1, p.localPublicKey)
|
||||
libp2pProof.write(2, signedPayload.getBytes())
|
||||
# data field also there but not used!
|
||||
libp2pProof.finish()
|
||||
|
||||
let handshakeRes =
|
||||
var handshakeRes =
|
||||
if initiator:
|
||||
await handshakeXXOutbound(p, conn, libp2pProof)
|
||||
await handshakeXXOutbound(p, conn, libp2pProof.buffer)
|
||||
else:
|
||||
await handshakeXXInbound(p, conn, libp2pProof)
|
||||
await handshakeXXInbound(p, conn, libp2pProof.buffer)
|
||||
|
||||
var
|
||||
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
|
||||
remotePubKey: PublicKey
|
||||
remotePubKeyBytes: seq[byte]
|
||||
remoteSig: Signature
|
||||
remoteSigBytes: seq[byte]
|
||||
var secure = try:
|
||||
var
|
||||
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
|
||||
remotePubKey: PublicKey
|
||||
remotePubKeyBytes: seq[byte]
|
||||
remoteSig: Signature
|
||||
remoteSigBytes: seq[byte]
|
||||
|
||||
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not(remoteProof.getField(1, remotePubKeyBytes)):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not(remoteProof.getField(2, remoteSigBytes)):
|
||||
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
|
||||
if not remotePubKey.init(remotePubKeyBytes):
|
||||
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not remoteSig.init(remoteSigBytes):
|
||||
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not remotePubKey.init(remotePubKeyBytes):
|
||||
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
if not remoteSig.init(remoteSigBytes):
|
||||
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
|
||||
|
||||
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
||||
if not remoteSig.verify(verifyPayload, remotePubKey):
|
||||
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
||||
else:
|
||||
trace "Remote signature verified", peer = $conn
|
||||
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
|
||||
if not remoteSig.verify(verifyPayload, remotePubKey):
|
||||
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
|
||||
else:
|
||||
trace "Remote signature verified", peer = $conn
|
||||
|
||||
if initiator and not isNil(conn.peerInfo):
|
||||
let pid = PeerID.init(remotePubKey)
|
||||
if not conn.peerInfo.peerId.validate():
|
||||
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
||||
if pid.isErr or pid.get() != conn.peerInfo.peerId:
|
||||
var
|
||||
failedKey: PublicKey
|
||||
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
|
||||
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
|
||||
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
|
||||
if initiator and not isNil(conn.peerInfo):
|
||||
let pid = PeerID.init(remotePubKey)
|
||||
if not conn.peerInfo.peerId.validate():
|
||||
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
|
||||
if pid.isErr or pid.get() != conn.peerInfo.peerId:
|
||||
var
|
||||
failedKey: PublicKey
|
||||
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
|
||||
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
|
||||
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
|
||||
|
||||
var secure = NoiseConnection.init(conn,
|
||||
PeerInfo.init(remotePubKey),
|
||||
conn.observedAddr)
|
||||
if initiator:
|
||||
secure.readCs = handshakeRes.cs2
|
||||
secure.writeCs = handshakeRes.cs1
|
||||
else:
|
||||
secure.readCs = handshakeRes.cs1
|
||||
secure.writeCs = handshakeRes.cs2
|
||||
var tmp = NoiseConnection.init(
|
||||
conn, PeerInfo.init(remotePubKey), conn.observedAddr)
|
||||
|
||||
if initiator:
|
||||
tmp.readCs = handshakeRes.cs2
|
||||
tmp.writeCs = handshakeRes.cs1
|
||||
else:
|
||||
tmp.readCs = handshakeRes.cs1
|
||||
tmp.writeCs = handshakeRes.cs2
|
||||
tmp
|
||||
finally:
|
||||
burnMem(handshakeRes)
|
||||
|
||||
trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
|
||||
|
||||
return secure
|
||||
|
||||
method close*(s: NoiseConnection) {.async.} =
|
||||
await procCall SecureConn(s).close()
|
||||
|
||||
burnMem(s.readCs)
|
||||
burnMem(s.writeCs)
|
||||
|
||||
method init*(p: Noise) {.gcsafe.} =
|
||||
procCall Secure(p).init()
|
||||
p.codec = NoiseCodec
|
||||
|
@ -491,7 +510,7 @@ proc newNoise*(
|
|||
rng: rng,
|
||||
outgoing: outgoing,
|
||||
localPrivateKey: privateKey,
|
||||
localPublicKey: privateKey.getKey().tryGet(),
|
||||
localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(),
|
||||
noiseKeys: genKeyPair(rng[]),
|
||||
commonPrologue: commonPrologue,
|
||||
)
|
||||
|
|
|
@ -44,11 +44,11 @@ method initStream*(s: SecureConn) =
|
|||
procCall Connection(s).initStream()
|
||||
|
||||
method close*(s: SecureConn) {.async.} =
|
||||
await procCall Connection(s).close()
|
||||
|
||||
if not(isNil(s.stream)):
|
||||
await s.stream.close()
|
||||
|
||||
await procCall Connection(s).close()
|
||||
|
||||
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
|
||||
doAssert(false, "Not implemented!")
|
||||
|
||||
|
@ -61,10 +61,9 @@ proc handleConn*(s: Secure,
|
|||
conn: Connection,
|
||||
initiator: bool): Future[Connection] {.async, gcsafe.} =
|
||||
var sconn = await s.handshake(conn, initiator)
|
||||
|
||||
conn.closeEvent.wait()
|
||||
.addCallback do(udata: pointer = nil):
|
||||
if not(isNil(sconn)):
|
||||
if not isNil(sconn):
|
||||
conn.closeEvent.wait()
|
||||
.addCallback do(udata: pointer = nil):
|
||||
asyncCheck sconn.close()
|
||||
|
||||
return sconn
|
||||
|
|
|
@ -198,6 +198,15 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
|
|||
await s.dataReadEvent.wait()
|
||||
s.dataReadEvent.clear()
|
||||
|
||||
proc drainBuffer*(s: BufferStream) {.async.} =
|
||||
## wait for all data in the buffer to be consumed
|
||||
##
|
||||
|
||||
trace "draining buffer", len = s.len
|
||||
while s.len > 0:
|
||||
await s.dataReadEvent.wait()
|
||||
s.dataReadEvent.clear()
|
||||
|
||||
method readOnce*(s: BufferStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
|
|
|
@ -7,9 +7,8 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import oids
|
||||
import chronos, chronicles
|
||||
import connection, ../utility
|
||||
import connection
|
||||
|
||||
logScope:
|
||||
topics = "chronosstream"
|
||||
|
@ -75,15 +74,10 @@ method close*(s: ChronosStream) {.async.} =
|
|||
if not s.isClosed:
|
||||
trace "shutting down chronos stream", address = $s.client.remoteAddress(),
|
||||
oid = s.oid
|
||||
|
||||
# TODO: the sequence here matters
|
||||
# don't move it after the connections
|
||||
# close bellow
|
||||
if not s.client.closed():
|
||||
await s.client.closeWait()
|
||||
|
||||
await procCall Connection(s).close()
|
||||
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as exc:
|
||||
|
|
|
@ -102,22 +102,32 @@ method readOnce*(s: LPStream,
|
|||
doAssert(false, "not implemented!")
|
||||
|
||||
proc readExactly*(s: LPStream,
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async.} =
|
||||
pbytes: pointer,
|
||||
nbytes: int):
|
||||
Future[void] {.async.} =
|
||||
|
||||
if s.atEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
logScope:
|
||||
nbytes = nbytes
|
||||
obName = s.objName
|
||||
stack = getStackTrace()
|
||||
oid = $s.oid
|
||||
|
||||
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes)
|
||||
var read = 0
|
||||
while read < nbytes and not(s.atEof()):
|
||||
read += await s.readOnce(addr pbuffer[read], nbytes - read)
|
||||
|
||||
if read < nbytes:
|
||||
trace "incomplete data received", read
|
||||
raise newLPStreamIncompleteError()
|
||||
|
||||
proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, deprecated: "todo".} =
|
||||
proc readLine*(s: LPStream,
|
||||
limit = 0,
|
||||
sep = "\r\n"): Future[string]
|
||||
{.async, deprecated: "todo".} =
|
||||
# TODO replace with something that exploits buffering better
|
||||
var lim = if limit <= 0: -1 else: limit
|
||||
var state = 0
|
||||
|
|
|
@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] =
|
|||
# TODO a double-hash comment here breaks compile (!)
|
||||
v.buf.toOpenArray(v.rpos, v.wpos - 1)
|
||||
|
||||
template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] =
|
||||
# Data that is ready to be consumed
|
||||
# TODO a double-hash comment here breaks compile (!)
|
||||
v.buf.toOpenArray(v.rpos + b, v.rpos + e - b)
|
||||
|
||||
func consume*(v: var StreamSeq, n: int) =
|
||||
## Mark `n` bytes that were returned via `data` as consumed
|
||||
v.rpos += n
|
||||
|
@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
|
|||
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
|
||||
v.consume(bytes)
|
||||
bytes
|
||||
|
||||
func clear*(v: var StreamSeq) =
|
||||
v.consume(v.len)
|
||||
|
||||
func assign*(v: var StreamSeq, buf: openArray[byte]) =
|
||||
v.clear()
|
||||
v.add(buf)
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
import tables,
|
||||
sequtils,
|
||||
options,
|
||||
strformat,
|
||||
sets,
|
||||
algorithm,
|
||||
oids
|
||||
|
@ -20,19 +19,17 @@ import chronos,
|
|||
metrics
|
||||
|
||||
import stream/connection,
|
||||
stream/chronosstream,
|
||||
transports/transport,
|
||||
multistream,
|
||||
multiaddress,
|
||||
protocols/protocol,
|
||||
protocols/secure/secure,
|
||||
protocols/secure/plaintext, # for plain text
|
||||
peerinfo,
|
||||
protocols/identify,
|
||||
protocols/pubsub/pubsub,
|
||||
muxers/muxer,
|
||||
errors,
|
||||
peerid
|
||||
peerid,
|
||||
errors
|
||||
|
||||
logScope:
|
||||
topics = "switch"
|
||||
|
@ -307,12 +304,11 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
|
|||
conn.peerInfo.close()
|
||||
finally:
|
||||
await conn.close()
|
||||
libp2p_peers.set(s.connections.len.int64)
|
||||
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
libp2p_peers.set(s.connections.len.int64)
|
||||
|
||||
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
|
||||
let connections = s.connections.getOrDefault(peer.id)
|
||||
for connHolder in connections:
|
||||
|
|
|
@ -16,9 +16,6 @@ import transport,
|
|||
../stream/connection,
|
||||
../stream/chronosstream
|
||||
|
||||
when chronicles.enabledLogLevel == LogLevel.TRACE:
|
||||
import oids
|
||||
|
||||
logScope:
|
||||
topics = "tcptransport"
|
||||
|
||||
|
@ -74,7 +71,7 @@ proc connHandler*(t: TcpTransport,
|
|||
proc cleanup() {.async.} =
|
||||
try:
|
||||
await client.join()
|
||||
trace "cleaning up client", addrs = client.remoteAddress, connoid = conn.oid
|
||||
trace "cleaning up client", addrs = $client.remoteAddress, connoid = conn.oid
|
||||
if not(isNil(conn)):
|
||||
await conn.close()
|
||||
t.clients.keepItIf(it != client)
|
||||
|
|
|
@ -7,12 +7,11 @@
|
|||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import sequtils, tables
|
||||
import sequtils
|
||||
import chronos, chronicles
|
||||
import ../stream/connection,
|
||||
../multiaddress,
|
||||
../multicodec,
|
||||
../errors
|
||||
../multicodec
|
||||
|
||||
type
|
||||
ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.}
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
{.used.}
|
||||
|
||||
import unittest, sequtils, options, tables, sets
|
||||
import unittest, sequtils, options, tables
|
||||
import chronos, stew/byteutils
|
||||
import utils,
|
||||
../../libp2p/[errors,
|
||||
|
@ -18,8 +18,7 @@ import utils,
|
|||
crypto/crypto,
|
||||
protocols/pubsub/pubsub,
|
||||
protocols/pubsub/floodsub,
|
||||
protocols/pubsub/rpc/messages,
|
||||
protocols/pubsub/rpc/message]
|
||||
protocols/pubsub/rpc/messages]
|
||||
|
||||
import ../helpers
|
||||
|
||||
|
@ -29,7 +28,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
|
|||
var ceil = 15
|
||||
let fsub = cast[FloodSub](sender.pubSub.get())
|
||||
while not fsub.floodsub.hasKey(key) or
|
||||
not fsub.floodsub[key].contains(receiver.peerInfo.id):
|
||||
not fsub.floodsub.hasPeerID(key, receiver.peerInfo.id):
|
||||
await sleepAsync(100.millis)
|
||||
dec ceil
|
||||
doAssert(ceil > 0, "waitSub timeout!")
|
||||
|
|
|
@ -29,17 +29,19 @@ suite "GossipSub internal":
|
|||
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo())
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.mesh[topic] = initHashSet[string]()
|
||||
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
|
||||
|
||||
var conns = newSeq[Connection]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
for i in 0..<15:
|
||||
let conn = newBufferStream(noop)
|
||||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].conn = conn
|
||||
gossipSub.mesh[topic].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.conn = conn
|
||||
gossipSub.peers[peerInfo.id] = peer
|
||||
gossipSub.mesh[topic].incl(peer)
|
||||
|
||||
check gossipSub.peers.len == 15
|
||||
await gossipSub.rebalanceMesh(topic)
|
||||
|
@ -57,18 +59,20 @@ suite "GossipSub internal":
|
|||
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo())
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.mesh[topic] = initHashSet[string]()
|
||||
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.topics[topic] = Topic() # has to be in topics to rebalance
|
||||
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
var conns = newSeq[Connection]()
|
||||
for i in 0..<15:
|
||||
let conn = newBufferStream(noop)
|
||||
conns &= conn
|
||||
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].conn = conn
|
||||
gossipSub.mesh[topic].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.conn = conn
|
||||
gossipSub.peers[peerInfo.id] = peer
|
||||
gossipSub.mesh[topic].incl(peer)
|
||||
|
||||
check gossipSub.mesh[topic].len == 15
|
||||
await gossipSub.rebalanceMesh(topic)
|
||||
|
@ -89,7 +93,7 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.gossipsub[topic] = initHashSet[string]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
|
||||
var conns = newSeq[Connection]()
|
||||
for i in 0..<15:
|
||||
|
@ -97,10 +101,9 @@ suite "GossipSub internal":
|
|||
conns &= conn
|
||||
var peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
gossipSub.peers[peerInfo.id].topics &= topic
|
||||
gossipSub.gossipsub[topic].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
gossipSub.gossipsub[topic].incl(peer)
|
||||
|
||||
check gossipSub.gossipsub[topic].len == 15
|
||||
gossipSub.replenishFanout(topic)
|
||||
|
@ -121,7 +124,7 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.fanout[topic] = initHashSet[string]()
|
||||
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.lastFanoutPubSub[topic] = Moment.fromNow(1.millis)
|
||||
await sleepAsync(5.millis) # allow the topic to expire
|
||||
|
||||
|
@ -131,13 +134,13 @@ suite "GossipSub internal":
|
|||
conns &= conn
|
||||
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
gossipSub.fanout[topic].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
gossipSub.fanout[topic].incl(peer)
|
||||
|
||||
check gossipSub.fanout[topic].len == GossipSubD
|
||||
|
||||
await gossipSub.dropFanoutPeers()
|
||||
gossipSub.dropFanoutPeers()
|
||||
check topic notin gossipSub.fanout
|
||||
|
||||
await allFuturesThrowing(conns.mapIt(it.close()))
|
||||
|
@ -156,8 +159,8 @@ suite "GossipSub internal":
|
|||
|
||||
let topic1 = "foobar1"
|
||||
let topic2 = "foobar2"
|
||||
gossipSub.fanout[topic1] = initHashSet[string]()
|
||||
gossipSub.fanout[topic2] = initHashSet[string]()
|
||||
gossipSub.fanout[topic1] = initHashSet[PubSubPeer]()
|
||||
gossipSub.fanout[topic2] = initHashSet[PubSubPeer]()
|
||||
gossipSub.lastFanoutPubSub[topic1] = Moment.fromNow(1.millis)
|
||||
gossipSub.lastFanoutPubSub[topic2] = Moment.fromNow(1.minutes)
|
||||
await sleepAsync(5.millis) # allow the topic to expire
|
||||
|
@ -168,15 +171,15 @@ suite "GossipSub internal":
|
|||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
gossipSub.fanout[topic1].incl(peerInfo.id)
|
||||
gossipSub.fanout[topic2].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
gossipSub.fanout[topic1].incl(peer)
|
||||
gossipSub.fanout[topic2].incl(peer)
|
||||
|
||||
check gossipSub.fanout[topic1].len == GossipSubD
|
||||
check gossipSub.fanout[topic2].len == GossipSubD
|
||||
|
||||
await gossipSub.dropFanoutPeers()
|
||||
gossipSub.dropFanoutPeers()
|
||||
check topic1 notin gossipSub.fanout
|
||||
check topic2 in gossipSub.fanout
|
||||
|
||||
|
@ -195,9 +198,9 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.mesh[topic] = initHashSet[string]()
|
||||
gossipSub.fanout[topic] = initHashSet[string]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[string]()
|
||||
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
var conns = newSeq[Connection]()
|
||||
|
||||
# generate mesh and fanout peers
|
||||
|
@ -206,12 +209,12 @@ suite "GossipSub internal":
|
|||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
if i mod 2 == 0:
|
||||
gossipSub.fanout[topic].incl(peerInfo.id)
|
||||
gossipSub.fanout[topic].incl(peer)
|
||||
else:
|
||||
gossipSub.mesh[topic].incl(peerInfo.id)
|
||||
gossipSub.mesh[topic].incl(peer)
|
||||
|
||||
# generate gossipsub (free standing) peers
|
||||
for i in 0..<15:
|
||||
|
@ -219,9 +222,9 @@ suite "GossipSub internal":
|
|||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
gossipSub.gossipsub[topic].incl(peerInfo.id)
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
gossipSub.gossipsub[topic].incl(peer)
|
||||
|
||||
# generate messages
|
||||
for i in 0..5:
|
||||
|
@ -239,8 +242,8 @@ suite "GossipSub internal":
|
|||
let peers = gossipSub.getGossipPeers()
|
||||
check peers.len == GossipSubD
|
||||
for p in peers.keys:
|
||||
check p notin gossipSub.fanout[topic]
|
||||
check p notin gossipSub.mesh[topic]
|
||||
check not gossipSub.fanout.hasPeerID(topic, p)
|
||||
check not gossipSub.mesh.hasPeerID(topic, p)
|
||||
|
||||
await allFuturesThrowing(conns.mapIt(it.close()))
|
||||
|
||||
|
@ -257,20 +260,20 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.fanout[topic] = initHashSet[string]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[string]()
|
||||
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
var conns = newSeq[Connection]()
|
||||
for i in 0..<30:
|
||||
let conn = newBufferStream(noop)
|
||||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
if i mod 2 == 0:
|
||||
gossipSub.fanout[topic].incl(peerInfo.id)
|
||||
gossipSub.fanout[topic].incl(peer)
|
||||
else:
|
||||
gossipSub.gossipsub[topic].incl(peerInfo.id)
|
||||
gossipSub.gossipsub[topic].incl(peer)
|
||||
|
||||
# generate messages
|
||||
for i in 0..5:
|
||||
|
@ -299,20 +302,20 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.mesh[topic] = initHashSet[string]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[string]()
|
||||
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
|
||||
var conns = newSeq[Connection]()
|
||||
for i in 0..<30:
|
||||
let conn = newBufferStream(noop)
|
||||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
if i mod 2 == 0:
|
||||
gossipSub.mesh[topic].incl(peerInfo.id)
|
||||
gossipSub.mesh[topic].incl(peer)
|
||||
else:
|
||||
gossipSub.gossipsub[topic].incl(peerInfo.id)
|
||||
gossipSub.gossipsub[topic].incl(peer)
|
||||
|
||||
# generate messages
|
||||
for i in 0..5:
|
||||
|
@ -341,20 +344,20 @@ suite "GossipSub internal":
|
|||
discard
|
||||
|
||||
let topic = "foobar"
|
||||
gossipSub.mesh[topic] = initHashSet[string]()
|
||||
gossipSub.fanout[topic] = initHashSet[string]()
|
||||
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
|
||||
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
|
||||
var conns = newSeq[Connection]()
|
||||
for i in 0..<30:
|
||||
let conn = newBufferStream(noop)
|
||||
conns &= conn
|
||||
let peerInfo = randomPeerInfo()
|
||||
conn.peerInfo = peerInfo
|
||||
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
|
||||
gossipSub.peers[peerInfo.id].handler = handler
|
||||
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
|
||||
peer.handler = handler
|
||||
if i mod 2 == 0:
|
||||
gossipSub.mesh[topic].incl(peerInfo.id)
|
||||
gossipSub.mesh[topic].incl(peer)
|
||||
else:
|
||||
gossipSub.fanout[topic].incl(peerInfo.id)
|
||||
gossipSub.fanout[topic].incl(peer)
|
||||
|
||||
# generate messages
|
||||
for i in 0..5:
|
||||
|
|
|
@ -32,11 +32,11 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
|
|||
var ceil = 15
|
||||
let fsub = GossipSub(sender.pubSub.get())
|
||||
while (not fsub.gossipsub.hasKey(key) or
|
||||
not fsub.gossipsub[key].contains(receiver.peerInfo.id)) and
|
||||
not fsub.gossipsub.hasPeerID(key, receiver.peerInfo.id)) and
|
||||
(not fsub.mesh.hasKey(key) or
|
||||
not fsub.mesh[key].contains(receiver.peerInfo.id)) and
|
||||
not fsub.mesh.hasPeerID(key, receiver.peerInfo.id)) and
|
||||
(not fsub.fanout.hasKey(key) or
|
||||
not fsub.fanout[key].contains(receiver.peerInfo.id)):
|
||||
not fsub.fanout.hasPeerID(key , receiver.peerInfo.id)):
|
||||
trace "waitSub sleeping..."
|
||||
await sleepAsync(1.seconds)
|
||||
dec ceil
|
||||
|
@ -192,7 +192,7 @@ suite "GossipSub":
|
|||
check:
|
||||
"foobar" in gossip2.topics
|
||||
"foobar" in gossip1.gossipsub
|
||||
gossip2.peerInfo.id in gossip1.gossipsub["foobar"]
|
||||
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id)
|
||||
|
||||
await allFuturesThrowing(nodes.mapIt(it.stop()))
|
||||
await allFuturesThrowing(awaitters)
|
||||
|
@ -236,11 +236,11 @@ suite "GossipSub":
|
|||
"foobar" in gossip1.gossipsub
|
||||
"foobar" in gossip2.gossipsub
|
||||
|
||||
gossip2.peerInfo.id in gossip1.gossipsub["foobar"] or
|
||||
gossip2.peerInfo.id in gossip1.mesh["foobar"]
|
||||
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) or
|
||||
gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id)
|
||||
|
||||
gossip1.peerInfo.id in gossip2.gossipsub["foobar"] or
|
||||
gossip1.peerInfo.id in gossip2.mesh["foobar"]
|
||||
gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.id) or
|
||||
gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id)
|
||||
|
||||
await allFuturesThrowing(nodes.mapIt(it.stop()))
|
||||
await allFuturesThrowing(awaitters)
|
||||
|
|
|
@ -0,0 +1,677 @@
|
|||
## Nim-Libp2p
|
||||
## Copyright (c) 2018 Status Research & Development GmbH
|
||||
## Licensed under either of
|
||||
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||
## at your option.
|
||||
## This file may not be copied, modified, or distributed except according to
|
||||
## those terms.
|
||||
|
||||
import unittest
|
||||
import ../libp2p/protobuf/minprotobuf
|
||||
import stew/byteutils, strutils
|
||||
|
||||
when defined(nimHasUsed): {.used.}
|
||||
|
||||
suite "MinProtobuf test suite":
|
||||
const VarintVectors = [
|
||||
"0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f",
|
||||
"08ffffffffffffffffff01"
|
||||
]
|
||||
|
||||
const VarintValues = [
|
||||
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64,
|
||||
0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
]
|
||||
|
||||
const Fixed32Vectors = [
|
||||
"0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff"
|
||||
]
|
||||
|
||||
const Fixed32Values = [
|
||||
0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32
|
||||
]
|
||||
|
||||
const Fixed64Vectors = [
|
||||
"090000000000000000", "090100000000000000", "09ffffff7f00000000",
|
||||
"09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f",
|
||||
"099988ffeeddccbbaa", "09ffffffffffffffff"
|
||||
]
|
||||
|
||||
const Fixed64Values = [
|
||||
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64,
|
||||
0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64,
|
||||
0xFFFF_FFFF_FFFF_FFFF'u64
|
||||
]
|
||||
|
||||
const LengthVectors = [
|
||||
"0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768"
|
||||
]
|
||||
|
||||
const LengthValues = [
|
||||
"", "a", "ab", "abcd", "abcdefgh"
|
||||
]
|
||||
|
||||
## This vector values was tested with `protoc` and related proto file.
|
||||
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated uint64 d = 1 [packed=true];
|
||||
## repeated uint64 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedVarintVector =
|
||||
"0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001"
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated sfixed32 d = 1 [packed=true];
|
||||
## repeated sfixed32 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedFixed32Vector =
|
||||
"0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000"
|
||||
## syntax = "proto2";
|
||||
## message testmsg {
|
||||
## repeated sfixed64 d = 1 [packed=true];
|
||||
## repeated sfixed64 d = 2 [packed=true];
|
||||
## }
|
||||
const PackedFixed64Vector =
|
||||
"""0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000
|
||||
ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210
|
||||
00000000000000000100000000000000"""
|
||||
|
||||
proc getVarintEncodedValue(value: uint64): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getVarintDecodedValue(data: openarray[byte]): uint64 =
|
||||
var value: uint64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
value
|
||||
|
||||
proc getFixed32EncodedValue(value: float32): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getFixed32DecodedValue(data: openarray[byte]): uint32 =
|
||||
var value: float32
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
cast[uint32](value)
|
||||
|
||||
proc getFixed64EncodedValue(value: float64): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getFixed64DecodedValue(data: openarray[byte]): uint64 =
|
||||
var value: float64
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value)
|
||||
doAssert(res)
|
||||
cast[uint64](value)
|
||||
|
||||
proc getLengthEncodedValue(value: string): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getLengthEncodedValue(value: seq[byte]): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, value)
|
||||
pb.finish()
|
||||
return pb.buffer
|
||||
|
||||
proc getLengthDecodedValue(data: openarray[byte]): string =
|
||||
var value = newString(len(data))
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
let res = pb.getField(1, value, valueLen)
|
||||
|
||||
doAssert(res)
|
||||
value.setLen(valueLen)
|
||||
value
|
||||
|
||||
proc isFullZero[T: byte|char](data: openarray[T]): bool =
|
||||
for ch in data:
|
||||
if int(ch) != 0:
|
||||
return false
|
||||
return true
|
||||
|
||||
proc corruptHeader(data: var openarray[byte], index: int) =
|
||||
var values = [3, 4, 6]
|
||||
data[0] = data[0] and 0xF8'u8
|
||||
data[0] = data[0] or byte(values[index mod len(values)])
|
||||
|
||||
test "[varint] edge values test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
let data = getVarintEncodedValue(VarintValues[i])
|
||||
check:
|
||||
toHex(data) == VarintVectors[i]
|
||||
getVarintDecodedValue(data) == VarintValues[i]
|
||||
|
||||
test "[varint] mixing many values with same field number test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(VarintValues):
|
||||
let index = (i + k + 1) mod len(VarintValues)
|
||||
pb.write(1, VarintValues[index])
|
||||
pb.finish()
|
||||
check getVarintDecodedValue(pb.buffer) == VarintValues[i]
|
||||
|
||||
test "[varint] incorrect values test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[varint] non-existent field test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == 0'u64
|
||||
|
||||
test "[varint] corrupted header test":
|
||||
for i in 0 ..< len(VarintValues):
|
||||
for k in 0 ..< 3:
|
||||
var value: uint64
|
||||
var data = getVarintEncodedValue(VarintValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[varint] empty buffer test":
|
||||
var value: uint64
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == 0'u64
|
||||
|
||||
test "[varint] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, VarintValues[1])
|
||||
pb1.write(1, VarintValues[2])
|
||||
pb1.write(2, VarintValues[3])
|
||||
pb1.write(1, VarintValues[4])
|
||||
pb1.write(1, VarintValues[5])
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[uint64]
|
||||
var fieldarr2: seq[uint64]
|
||||
var fieldarr3: seq[uint64]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
fieldarr1[0] == VarintValues[1]
|
||||
fieldarr1[1] == VarintValues[2]
|
||||
fieldarr1[2] == VarintValues[4]
|
||||
fieldarr1[3] == VarintValues[5]
|
||||
fieldarr2[0] == VarintValues[3]
|
||||
|
||||
test "[varint] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.writePacked(1, VarintValues)
|
||||
pb1.writePacked(2, VarintValues[0 .. 1])
|
||||
pb1.finish()
|
||||
check:
|
||||
toHex(pb1.buffer) == PackedVarintVector
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[uint64]
|
||||
var fieldarr2: seq[uint64]
|
||||
var fieldarr3: seq[uint64]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 6
|
||||
fieldarr1[0] == VarintValues[0]
|
||||
fieldarr1[1] == VarintValues[1]
|
||||
fieldarr1[2] == VarintValues[2]
|
||||
fieldarr1[3] == VarintValues[3]
|
||||
fieldarr1[4] == VarintValues[4]
|
||||
fieldarr1[5] == VarintValues[5]
|
||||
fieldarr2[0] == VarintValues[0]
|
||||
fieldarr2[1] == VarintValues[1]
|
||||
|
||||
test "[fixed32] edge values test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i]))
|
||||
check:
|
||||
toHex(data) == Fixed32Vectors[i]
|
||||
getFixed32DecodedValue(data) == Fixed32Values[i]
|
||||
|
||||
test "[fixed32] mixing many values with same field number test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(Fixed32Values):
|
||||
let index = (i + k + 1) mod len(Fixed32Values)
|
||||
pb.write(1, cast[float32](Fixed32Values[index]))
|
||||
pb.finish()
|
||||
check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i]
|
||||
|
||||
test "[fixed32] incorrect values test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed32] non-existent field test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float32(0)
|
||||
|
||||
test "[fixed32] corrupted header test":
|
||||
for i in 0 ..< len(Fixed32Values):
|
||||
for k in 0 ..< 3:
|
||||
var value: float32
|
||||
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed32] empty buffer test":
|
||||
var value: float32
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float32(0)
|
||||
|
||||
test "[fixed32] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, cast[float32](Fixed32Values[0]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[1]))
|
||||
pb1.write(2, cast[float32](Fixed32Values[2]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[3]))
|
||||
pb1.write(1, cast[float32](Fixed32Values[4]))
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float32]
|
||||
var fieldarr2: seq[float32]
|
||||
var fieldarr3: seq[float32]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[uint32](fieldarr1[0]) == Fixed64Values[0]
|
||||
cast[uint32](fieldarr1[1]) == Fixed64Values[1]
|
||||
cast[uint32](fieldarr1[2]) == Fixed64Values[3]
|
||||
cast[uint32](fieldarr1[3]) == Fixed64Values[4]
|
||||
cast[uint32](fieldarr2[0]) == Fixed64Values[2]
|
||||
|
||||
test "[fixed32] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var values = newSeq[float32](len(Fixed32Values))
|
||||
for i in 0 ..< len(values):
|
||||
values[i] = cast[float32](Fixed32Values[i])
|
||||
pb1.writePacked(1, values)
|
||||
pb1.writePacked(2, values[0 .. 1])
|
||||
pb1.finish()
|
||||
check:
|
||||
toHex(pb1.buffer) == PackedFixed32Vector
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float32]
|
||||
var fieldarr2: seq[float32]
|
||||
var fieldarr3: seq[float32]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 5
|
||||
cast[uint32](fieldarr1[0]) == Fixed32Values[0]
|
||||
cast[uint32](fieldarr1[1]) == Fixed32Values[1]
|
||||
cast[uint32](fieldarr1[2]) == Fixed32Values[2]
|
||||
cast[uint32](fieldarr1[3]) == Fixed32Values[3]
|
||||
cast[uint32](fieldarr1[4]) == Fixed32Values[4]
|
||||
cast[uint32](fieldarr2[0]) == Fixed32Values[0]
|
||||
cast[uint32](fieldarr2[1]) == Fixed32Values[1]
|
||||
|
||||
test "[fixed64] edge values test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
check:
|
||||
toHex(data) == Fixed64Vectors[i]
|
||||
getFixed64DecodedValue(data) == Fixed64Values[i]
|
||||
|
||||
test "[fixed64] mixing many values with same field number test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var pb = initProtoBuffer()
|
||||
for k in 0 ..< len(Fixed64Values):
|
||||
let index = (i + k + 1) mod len(Fixed64Values)
|
||||
pb.write(1, cast[float64](Fixed64Values[index]))
|
||||
pb.finish()
|
||||
check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i]
|
||||
|
||||
test "[fixed64] incorrect values test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var value: float32
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed64] non-existent field test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
var value: float64
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value) == false
|
||||
value == float64(0)
|
||||
|
||||
test "[fixed64] corrupted header test":
|
||||
for i in 0 ..< len(Fixed64Values):
|
||||
for k in 0 ..< 3:
|
||||
var value: float64
|
||||
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
|
||||
test "[fixed64] empty buffer test":
|
||||
var value: float64
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value) == false
|
||||
value == float64(0)
|
||||
|
||||
test "[fixed64] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, cast[float64](Fixed64Values[2]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[3]))
|
||||
pb1.write(2, cast[float64](Fixed64Values[4]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[5]))
|
||||
pb1.write(1, cast[float64](Fixed64Values[6]))
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float64]
|
||||
var fieldarr2: seq[float64]
|
||||
var fieldarr3: seq[float64]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[uint64](fieldarr1[0]) == Fixed64Values[2]
|
||||
cast[uint64](fieldarr1[1]) == Fixed64Values[3]
|
||||
cast[uint64](fieldarr1[2]) == Fixed64Values[5]
|
||||
cast[uint64](fieldarr1[3]) == Fixed64Values[6]
|
||||
cast[uint64](fieldarr2[0]) == Fixed64Values[4]
|
||||
|
||||
test "[fixed64] Repeated packed field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var values = newSeq[float64](len(Fixed64Values))
|
||||
for i in 0 ..< len(values):
|
||||
values[i] = cast[float64](Fixed64Values[i])
|
||||
pb1.writePacked(1, values)
|
||||
pb1.writePacked(2, values[0 .. 1])
|
||||
pb1.finish()
|
||||
let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", ""))
|
||||
check:
|
||||
toHex(pb1.buffer) == expect
|
||||
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[float64]
|
||||
var fieldarr2: seq[float64]
|
||||
var fieldarr3: seq[float64]
|
||||
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 2
|
||||
len(fieldarr1) == 8
|
||||
cast[uint64](fieldarr1[0]) == Fixed64Values[0]
|
||||
cast[uint64](fieldarr1[1]) == Fixed64Values[1]
|
||||
cast[uint64](fieldarr1[2]) == Fixed64Values[2]
|
||||
cast[uint64](fieldarr1[3]) == Fixed64Values[3]
|
||||
cast[uint64](fieldarr1[4]) == Fixed64Values[4]
|
||||
cast[uint64](fieldarr1[5]) == Fixed64Values[5]
|
||||
cast[uint64](fieldarr1[6]) == Fixed64Values[6]
|
||||
cast[uint64](fieldarr1[7]) == Fixed64Values[7]
|
||||
cast[uint64](fieldarr2[0]) == Fixed64Values[0]
|
||||
cast[uint64](fieldarr2[1]) == Fixed64Values[1]
|
||||
|
||||
test "[length] edge values test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
let data1 = getLengthEncodedValue(LengthValues[i])
|
||||
let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i]))
|
||||
check:
|
||||
toHex(data1) == LengthVectors[i]
|
||||
toHex(data2) == LengthVectors[i]
|
||||
check:
|
||||
getLengthDecodedValue(data1) == LengthValues[i]
|
||||
getLengthDecodedValue(data2) == LengthValues[i]
|
||||
|
||||
test "[length] mixing many values with same field number test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var pb1 = initProtoBuffer()
|
||||
var pb2 = initProtoBuffer()
|
||||
for k in 0 ..< len(LengthValues):
|
||||
let index = (i + k + 1) mod len(LengthValues)
|
||||
pb1.write(1, LengthValues[index])
|
||||
pb2.write(1, cast[seq[byte]](LengthValues[index]))
|
||||
pb1.finish()
|
||||
pb2.finish()
|
||||
check getLengthDecodedValue(pb1.buffer) == LengthValues[i]
|
||||
check getLengthDecodedValue(pb2.buffer) == LengthValues[i]
|
||||
|
||||
test "[length] incorrect values test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
# corrupting
|
||||
data.setLen(len(data) - 1)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] non-existent field test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(2, value, valueLen) == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] corrupted header test":
|
||||
for i in 0 ..< len(LengthValues):
|
||||
for k in 0 ..< 3:
|
||||
var value = newSeq[byte](len(LengthValues[i]))
|
||||
var valueLen = 0
|
||||
var data = getLengthEncodedValue(LengthValues[i])
|
||||
data.corruptHeader(k)
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] empty buffer test":
|
||||
var value = newSeq[byte](len(LengthValues[0]))
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer()
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
valueLen == 0
|
||||
|
||||
test "[length] buffer overflow test":
|
||||
for i in 1 ..< len(LengthValues):
|
||||
let data = getLengthEncodedValue(LengthValues[i])
|
||||
|
||||
var value = newString(len(LengthValues[i]) - 1)
|
||||
var valueLen = 0
|
||||
var pb = initProtoBuffer(data)
|
||||
check:
|
||||
pb.getField(1, value, valueLen) == false
|
||||
valueLen == len(LengthValues[i])
|
||||
isFullZero(value) == true
|
||||
|
||||
test "[length] mix of buffer overflow and normal fields test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, "TEST10")
|
||||
pb1.write(1, "TEST20")
|
||||
pb1.write(1, "TEST")
|
||||
pb1.write(1, "TEST30")
|
||||
pb1.write(1, "SOME")
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(4)
|
||||
var valueLen = 0
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == true
|
||||
value == "SOME"
|
||||
|
||||
test "[length] too big message test":
|
||||
var pb1 = initProtoBuffer()
|
||||
var bigString = newString(MaxMessageSize + 1)
|
||||
|
||||
for i in 0 ..< len(bigString):
|
||||
bigString[i] = 'A'
|
||||
pb1.write(1, bigString)
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var value = newString(MaxMessageSize + 1)
|
||||
var valueLen = 0
|
||||
check:
|
||||
pb2.getField(1, value, valueLen) == false
|
||||
|
||||
test "[length] Repeated field test":
|
||||
var pb1 = initProtoBuffer()
|
||||
pb1.write(1, "TEST1")
|
||||
pb1.write(1, "TEST2")
|
||||
pb1.write(2, "TEST5")
|
||||
pb1.write(1, "TEST3")
|
||||
pb1.write(1, "TEST4")
|
||||
pb1.finish()
|
||||
var pb2 = initProtoBuffer(pb1.buffer)
|
||||
var fieldarr1: seq[seq[byte]]
|
||||
var fieldarr2: seq[seq[byte]]
|
||||
var fieldarr3: seq[seq[byte]]
|
||||
let r1 = pb2.getRepeatedField(1, fieldarr1)
|
||||
let r2 = pb2.getRepeatedField(2, fieldarr2)
|
||||
let r3 = pb2.getRepeatedField(3, fieldarr3)
|
||||
check:
|
||||
r1 == true
|
||||
r2 == true
|
||||
r3 == false
|
||||
len(fieldarr3) == 0
|
||||
len(fieldarr2) == 1
|
||||
len(fieldarr1) == 4
|
||||
cast[string](fieldarr1[0]) == "TEST1"
|
||||
cast[string](fieldarr1[1]) == "TEST2"
|
||||
cast[string](fieldarr1[2]) == "TEST3"
|
||||
cast[string](fieldarr1[3]) == "TEST4"
|
||||
cast[string](fieldarr2[0]) == "TEST5"
|
||||
|
||||
test "Different value types in one message with same field number test":
|
||||
proc getEncodedValue(): seq[byte] =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write(1, VarintValues[1])
|
||||
pb.write(2, cast[float32](Fixed32Values[1]))
|
||||
pb.write(3, cast[float64](Fixed64Values[1]))
|
||||
pb.write(4, LengthValues[1])
|
||||
|
||||
pb.write(1, VarintValues[2])
|
||||
pb.write(2, cast[float32](Fixed32Values[2]))
|
||||
pb.write(3, cast[float64](Fixed64Values[2]))
|
||||
pb.write(4, LengthValues[2])
|
||||
|
||||
pb.write(1, cast[float32](Fixed32Values[3]))
|
||||
pb.write(2, cast[float64](Fixed64Values[3]))
|
||||
pb.write(3, LengthValues[3])
|
||||
pb.write(4, VarintValues[3])
|
||||
|
||||
pb.write(1, cast[float64](Fixed64Values[4]))
|
||||
pb.write(2, LengthValues[4])
|
||||
pb.write(3, VarintValues[4])
|
||||
pb.write(4, cast[float32](Fixed32Values[4]))
|
||||
|
||||
pb.write(1, VarintValues[1])
|
||||
pb.write(2, cast[float32](Fixed32Values[1]))
|
||||
pb.write(3, cast[float64](Fixed64Values[1]))
|
||||
pb.write(4, LengthValues[1])
|
||||
pb.finish()
|
||||
pb.buffer
|
||||
|
||||
let msg = getEncodedValue()
|
||||
let pb = initProtoBuffer(msg)
|
||||
var varintValue: uint64
|
||||
var fixed32Value: float32
|
||||
var fixed64Value: float64
|
||||
var lengthValue = newString(10)
|
||||
var lengthSize: int
|
||||
|
||||
check:
|
||||
pb.getField(1, varintValue) == true
|
||||
pb.getField(2, fixed32Value) == true
|
||||
pb.getField(3, fixed64Value) == true
|
||||
pb.getField(4, lengthValue, lengthSize) == true
|
||||
|
||||
lengthValue.setLen(lengthSize)
|
||||
|
||||
check:
|
||||
varintValue == VarintValues[1]
|
||||
cast[uint32](fixed32Value) == Fixed32Values[1]
|
||||
cast[uint64](fixed64Value) == Fixed64Values[1]
|
||||
lengthValue == LengthValues[1]
|
|
@ -1,4 +1,4 @@
|
|||
import unittest, strutils, sequtils, strformat, stew/byteutils
|
||||
import unittest, strutils, strformat, stew/byteutils
|
||||
import chronos
|
||||
import ../libp2p/errors,
|
||||
../libp2p/multistream,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import testvarint,
|
||||
testminprotobuf,
|
||||
teststreamseq
|
||||
|
||||
import testrsa,
|
||||
|
|
|
@ -9,19 +9,13 @@ import ../libp2p/[errors,
|
|||
multistream,
|
||||
standard_setup,
|
||||
stream/bufferstream,
|
||||
protocols/identify,
|
||||
stream/connection,
|
||||
transports/transport,
|
||||
transports/tcptransport,
|
||||
multiaddress,
|
||||
peerinfo,
|
||||
crypto/crypto,
|
||||
protocols/protocol,
|
||||
muxers/muxer,
|
||||
muxers/mplex/mplex,
|
||||
muxers/mplex/types,
|
||||
protocols/secure/secio,
|
||||
protocols/secure/secure,
|
||||
stream/lpstream]
|
||||
import ./helpers
|
||||
|
||||
|
|
Loading…
Reference in New Issue