Merge branch 'master' into gossip-one-one

This commit is contained in:
Giovanni Petrantoni 2020-07-15 08:47:43 +09:00
commit 8078fec0f0
35 changed files with 2327 additions and 926 deletions

View File

@ -22,13 +22,14 @@ matrix:
- NPROC=2
before_install:
- export GOPATH=$HOME/go
- os: linux
arch: arm64
env:
- NPROC=6 # Worth trying more than 2 parallel jobs: https://travis-ci.community/t/no-cache-support-on-arm64/5416/8
# (also used to get a different cache key than the amd64 one)
before_install:
- export GOPATH=$HOME/go
## arm64 is very unreliable and slow, disabled for now
# - os: linux
# arch: arm64
# env:
# - NPROC=6 # Worth trying more than 2 parallel jobs: https://travis-ci.community/t/no-cache-support-on-arm64/5416/8
# # (also used to get a different cache key than the amd64 one)
# before_install:
# - export GOPATH=$HOME/go
- os: osx
env:
- NPROC=2

View File

@ -17,12 +17,15 @@ requires "nim >= 1.2.0",
"stew >= 0.1.0"
proc runTest(filename: string, verify: bool = true, sign: bool = true) =
var excstr = "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info"
var excstr = "nim c --opt:speed -d:debug --verbosity:0 --hints:off"
excstr.add(" --warning[CaseTransition]:off --warning[ObservableStores]:off --warning[LockLevel]:off")
excstr.add(" -d:libp2p_pubsub_sign=" & $sign)
excstr.add(" -d:libp2p_pubsub_verify=" & $verify)
excstr.add(" tests/" & filename)
exec excstr
if verify and sign:
# build it with TRACE and JSON logs
exec excstr & " -d:chronicles_log_level=TRACE -d:chronicles_sinks:json" & " tests/" & filename
# build it again, to run it with less verbose logs
exec excstr & " -d:chronicles_log_level=INFO -r" & " tests/" & filename
rmFile "tests/" & filename.toExe
proc buildSample(filename: string) =

View File

@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] =
##
## Returns number of bytes (octets) needed to store private key ``key``.
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen:
@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] =
##
## Returns number of bytes (octets) needed to store public key ``key``.
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
var blen = len(msg.buffer)
if len(data) >= blen and blen > 0:
@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] =
## Return private key ``key`` in binary form (using libp2p's protobuf
## serialization).
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] =
## Return public key ``key`` in binary form (using libp2p's protobuf
## serialization).
var msg = initProtoBuffer()
msg.write(initProtoField(1, cast[uint64](key.scheme)))
msg.write(initProtoField(2, ? key.getRawBytes()))
msg.write(1, uint64(key.scheme))
msg.write(2, ? key.getRawBytes())
msg.finish()
ok(msg.buffer)
@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool =
var buffer: seq[byte]
if len(data) > 0:
var pb = initProtoBuffer(@data)
if pb.getVarintValue(1, id) != 0:
if pb.getBytes(2, buffer) != 0:
if cast[int8](id) in SupportedSchemesInt:
var scheme = cast[PKScheme](cast[int8](id))
when key is PrivateKey:
var nkey = PrivateKey(scheme: scheme)
else:
var nkey = PublicKey(scheme: scheme)
case scheme:
of PKScheme.RSA:
if init(nkey.rsakey, buffer).isOk:
key = nkey
return true
of PKScheme.Ed25519:
if init(nkey.edkey, buffer):
key = nkey
return true
of PKScheme.ECDSA:
if init(nkey.eckey, buffer).isOk:
key = nkey
return true
of PKScheme.Secp256k1:
if init(nkey.skkey, buffer).isOk:
key = nkey
return true
else:
return false
if pb.getField(1, id) and pb.getField(2, buffer):
if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0:
var scheme = cast[PKScheme](cast[int8](id))
when key is PrivateKey:
var nkey = PrivateKey(scheme: scheme)
else:
var nkey = PublicKey(scheme: scheme)
case scheme:
of PKScheme.RSA:
if init(nkey.rsakey, buffer).isOk:
key = nkey
return true
of PKScheme.Ed25519:
if init(nkey.edkey, buffer):
key = nkey
return true
of PKScheme.ECDSA:
if init(nkey.eckey, buffer).isOk:
key = nkey
return true
of PKScheme.Secp256k1:
if init(nkey.skkey, buffer).isOk:
key = nkey
return true
else:
return false
proc init*(sig: var Signature, data: openarray[byte]): bool =
## Initialize signature ``sig`` from raw binary form.
@ -374,6 +373,24 @@ proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] =
except ValueError:
err(KeyError)
proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey =
PrivateKey(scheme: RSA, rsakey: key)
proc init*(t: typedesc[PrivateKey], key: EdPrivateKey): PrivateKey =
PrivateKey(scheme: Ed25519, edkey: key)
proc init*(t: typedesc[PrivateKey], key: SkPrivateKey): PrivateKey =
PrivateKey(scheme: Secp256k1, skkey: key)
proc init*(t: typedesc[PrivateKey], key: ecnist.EcPrivateKey): PrivateKey =
PrivateKey(scheme: ECDSA, eckey: key)
proc init*(t: typedesc[PublicKey], key: rsa.RsaPublicKey): PublicKey =
PublicKey(scheme: RSA, rsakey: key)
proc init*(t: typedesc[PublicKey], key: EdPublicKey): PublicKey =
PublicKey(scheme: Ed25519, edkey: key)
proc init*(t: typedesc[PublicKey], key: SkPublicKey): PublicKey =
PublicKey(scheme: Secp256k1, skkey: key)
proc init*(t: typedesc[PublicKey], key: ecnist.EcPublicKey): PublicKey =
PublicKey(scheme: ECDSA, eckey: key)
proc init*(t: typedesc[PublicKey], data: string): CryptoResult[PublicKey] =
## Create new public key from libp2p's protobuf serialized hexadecimal string
## form.
@ -709,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte],
## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and
## comma-delimeted list of supported hashes ``hashes``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, nonce))
msg.write(initProtoField(2, pubkey))
msg.write(initProtoField(3, exchanges))
msg.write(initProtoField(4, ciphers))
msg.write(initProtoField(5, hashes))
msg.write(1, nonce)
msg.write(2, pubkey)
msg.write(3, exchanges)
msg.write(4, ciphers)
msg.write(5, hashes)
msg.finish()
shallowCopy(result, msg.buffer)
@ -726,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte],
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
if pb.getLengthValue(1, nonce) != -1 and
pb.getLengthValue(2, pubkey) != -1 and
pb.getLengthValue(3, exchanges) != -1 and
pb.getLengthValue(4, ciphers) != -1 and
pb.getLengthValue(5, hashes) != -1:
result = true
pb.getField(1, nonce) and pb.getField(2, pubkey) and
pb.getField(3, exchanges) and pb.getField(4, ciphers) and
pb.getField(5, hashes)
proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] =
## Create SecIO exchange message using ephemeral public key ``epubkey`` and
## signature of proposal blocks ``signature``.
var msg = initProtoBuffer({WithUint32BeLength})
msg.write(initProtoField(1, epubkey))
msg.write(initProtoField(2, signature))
msg.write(1, epubkey)
msg.write(2, signature)
msg.finish()
shallowCopy(result, msg.buffer)
@ -749,9 +763,7 @@ proc decodeExchange*(message: seq[byte],
##
## Procedure returns ``true`` on success and ``false`` on error.
var pb = initProtoBuffer(message)
if pb.getLengthValue(1, pubkey) != -1 and
pb.getLengthValue(2, signature) != -1:
result = true
pb.getField(1, pubkey) and pb.getField(2, signature)
## Serialization/Deserialization helpers
@ -770,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {.
## Write Signature value ``sig`` to buffer ``vb``.
vb.writeSeq(sig.getBytes().tryGet())
proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {.
raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PublicKey ``pubkey``.
result = initProtoField(index, pubkey.getBytes().tryGet())
proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int,
key: T) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
write(pb, field, key.getBytes().tryGet())
proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {.
raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PrivateKey ``seckey``.
result = initProtoField(index, seckey.getBytes().tryGet())
proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {.
inline, raises: [Defect, ResultError[CryptoError]].} =
write(pb, field, sig.getBytes())
proc initProtoField*(index: int, sig: Signature): ProtoField =
proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {.
deprecated, raises: [Defect, ResultError[CryptoError]].} =
## Initialize ProtoField with PublicKey/PrivateKey ``key``.
result = initProtoField(index, key.getBytes().tryGet())
proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} =
## Initialize ProtoField with Signature ``sig``.
result = initProtoField(index, sig.getBytes())
proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
## Read ``PublicKey`` from ProtoBuf's message and validate it.
proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int,
value: var T): int {.deprecated.} =
## Read PublicKey/PrivateKey from ProtoBuf's message and validate it.
var buf: seq[byte]
var key: PublicKey
result = getLengthValue(data, field, buf)
@ -795,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int =
else:
value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int =
## Read ``PrivateKey`` from ProtoBuf's message and validate it.
var buf: seq[byte]
var key: PrivateKey
result = getLengthValue(data, field, buf)
if result > 0:
if not key.init(buf):
result = -1
else:
value = key
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {.
deprecated.} =
## Read ``Signature`` from ProtoBuf's message and validate it.
var buf: seq[byte]
var sig: Signature
@ -816,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int =
result = -1
else:
value = sig
proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int,
value: var T): bool =
var buffer: seq[byte]
var key: T
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if key.init(buffer):
value = key
true
else:
false
proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool =
var buffer: seq[byte]
var sig: Signature
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if sig.init(buffer):
value = sig
true
else:
false

View File

@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key)
proc public*(private: Curve25519Key): Curve25519Key =
Curve25519.mulgen(result, private)
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] =
proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key =
var res: Curve25519Key
let defaultBrEc = brEcGetDefault()
if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize:
err(Curver25519GenError)
else:
ok(res)
let len = brEcKeygen(
addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519)
# Per bearssl documentation, the keygen only fails if the curve is
# unrecognised -
doAssert len == Curve25519KeySize, "Could not generate curve"
res

View File

@ -39,11 +39,11 @@ const
type
EcPrivateKey* = ref object
buffer*: seq[byte]
buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte]
key*: BrEcPrivateKey
EcPublicKey* = ref object
buffer*: seq[byte]
buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte]
key*: BrEcPublicKey
EcKeyPair* = object
@ -237,7 +237,6 @@ proc random*(
## secp521r1).
var ecimp = brEcGetDefault()
var res = new EcPrivateKey
res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE)
if brEcKeygen(addr rng.vtable, ecimp,
addr res.key, addr res.buffer[0],
cast[cint](kind)) == 0:
@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] =
if seckey.key.curve in EcSupportedCurvesCint:
var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve))
var res = new EcPublicKey
res.buffer = newSeq[byte](length)
if brEcComputePublicKey(ecimp, addr res.key,
addr res.buffer[0], unsafeAddr seckey.key) == 0:
err(EcKeyIncorrectError)
@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error
if checkScalar(raw.toOpenArray(), curve) == 1'u32:
key = new EcPrivateKey
key.buffer = newSeq[byte](raw.length)
copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length)
key.key.x = cast[ptr cuchar](addr key.buffer[0])
key.key.xlen = raw.length
@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err
if checkPublic(raw.toOpenArray(), curve) != 0:
pubkey = new EcPublicKey
pubkey.buffer = newSeq[byte](raw.length)
copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length)
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
pubkey.key.qlen = raw.length
@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool =
if checkScalar(data, curve) == 1'u32:
let length = len(data)
key = new EcPrivateKey
key.buffer = newSeq[byte](length)
copyMem(addr key.buffer[0], unsafeAddr data[0], length)
key.key.x = cast[ptr cuchar](addr key.buffer[0])
key.key.xlen = length
@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool =
if checkPublic(data, curve) != 0:
let length = len(data)
pubkey = new EcPublicKey
pubkey.buffer = newSeq[byte](length)
copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length)
pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0])
pubkey.key.qlen = length

View File

@ -14,9 +14,10 @@
import nativesockets
import tables, strutils, stew/shims/net
import chronos
import multicodec, multihash, multibase, transcoder, vbuffer, peerid
import multicodec, multihash, multibase, transcoder, vbuffer, peerid,
protobuf/minprotobuf
import stew/[base58, base32, endians2, results]
export results
export results, minprotobuf, vbuffer
type
MAKind* = enum
@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] =
else:
ok($(proto.mcodec))
proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] =
proc protoArgument*(ma: MultiAddress,
value: var openarray[byte]): MaResult[int] =
## Returns MultiAddress ``ma`` protocol argument value.
##
## If current MultiAddress do not have argument value, then result will be
@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int]
var res: int
if proto.kind == Fixed:
res = proto.size
if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
if len(value) >= res and
vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size:
err("multiaddress: Decoding protocol error")
else:
ok(res)
@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] =
let proto = CodeAddresses.getOrDefault(MultiCodec(header))
if proto.kind == None:
yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'")
yield err(MaResult[MultiAddress], "Unsupported protocol '" &
$header & "'")
elif proto.kind == Fixed:
data.setLen(proto.size)
@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} =
return ok(true)
ok(false)
proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
proc `[]`*(ma: MultiAddress,
codec: MultiCodec): MaResult[MultiAddress] {.inline.} =
## Returns partial MultiAddress with MultiCodec ``codec`` and present in
## MultiAddress ``ma``.
for item in ma.items:
@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] =
return err("multiaddress: Unsupported protocol '" & $header & "'")
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.bufferToString):
return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder")
return err("multiaddress: Missing protocol '" & $(proto.mcodec) &
"' coder")
if not proto.coder.bufferToString(vb.data, part):
return err("multiaddress: Decoding protocol error")
parts.add($(proto.mcodec))
@ -729,12 +734,14 @@ proc init*(
of None:
raiseAssert "None checked above"
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: PeerID): MaResult[MultiAddress] {.inline.} =
## Initialize MultiAddress object from protocol id ``protocol`` and peer id
## ``value``.
init(mtype, protocol, value.data)
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec,
value: int): MaResult[MultiAddress] =
## Initialize MultiAddress object from protocol id ``protocol`` and integer
## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``,
## ``dccp`` and ``sctp`` MultiAddresses.
@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} =
if mc != InvalidMultiCodec:
result = CodeAddresses.getOrDefault(mc)
proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
value: string): MaResult[MultiAddress] =
## Initialize MultiAddress object from string representation ``value``.
var parts = value.trimRight('/').split('/')
if len(parts[0]) != 0:
@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
else:
if proto.kind in {Fixed, Length, Path}:
if isNil(proto.coder.stringToBuffer):
return err("multiaddress: Missing protocol '" & part & "' transcoder")
return err("multiaddress: Missing protocol '" &
part & "' transcoder")
if offset + 1 >= len(parts):
return err("multiaddress: Missing protocol '" & part & "' argument")
@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.write(proto.mcodec)
let res = proto.coder.stringToBuffer(parts[offset + 1], res.data)
if not res:
return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`")
return err("multiaddress: Error encoding `" & part & "/" &
parts[offset + 1] & "`")
offset += 2
elif proto.kind == Path:
var path = "/" & (parts[(offset + 1)..^1].join("/"))
res.data.write(proto.mcodec)
if not proto.coder.stringToBuffer(path, res.data):
return err("multiaddress: Error encoding `" & part & "/" & path & "`")
return err("multiaddress: Error encoding `" & part & "/" &
path & "`")
break
elif proto.kind == Marker:
@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress]
res.data.finish()
ok(res)
proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] =
proc init*(mtype: typedesc[MultiAddress],
data: openarray[byte]): MaResult[MultiAddress] =
## Initialize MultiAddress with array of bytes ``data``.
if len(data) == 0:
err("multiaddress: Address could not be empty!")
@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress],
var data = initVBuffer()
data.write(familyProto.mcodec)
var written = familyProto.coder.stringToBuffer($address, data)
doAssert written, "Merely writing a string to a buffer should always be possible"
doAssert written,
"Merely writing a string to a buffer should always be possible"
data.write(protoProto.mcodec)
written = protoProto.coder.stringToBuffer($port, data)
doAssert written, "Merely writing a string to a buffer should always be possible"
doAssert written,
"Merely writing a string to a buffer should always be possible"
data.finish()
MultiAddress(data: data)
@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] =
else:
ok()
proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} =
proc `&`*(m1, m2: MultiAddress): MultiAddress {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``, and returns result.
##
## This procedure performs validation of concatenated result and can raise
## exception on error.
concat(m1, m2).tryGet()
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} =
proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.
raises: [Defect, ResultError[string]].} =
## Concatenates two addresses ``m1`` and ``m2``.
##
## This procedure performs validation of concatenated result and can raise
@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string =
result = "(" & sub.join("|") & ")"
elif pat.operator == Eq:
result = $pat.value
proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} =
write(pb, field, value.data.buffer)
proc getField*(pb: var ProtoBuffer, field: int,
value: var MultiAddress): bool {.inline.} =
var buffer: seq[byte]
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
let ma = MultiAddress.init(buffer)
if ma.isOk():
value = ma.get()
true
else:
false
proc getRepeatedField*(pb: var ProtoBuffer, field: int,
value: var seq[MultiAddress]): bool {.inline.} =
var items: seq[seq[byte]]
value.setLen(0)
if not(getRepeatedField(pb, field, items)):
return false
if len(items) == 0:
return true
for item in items:
let ma = MultiAddress.init(item)
if ma.isOk():
value.add(ma.get())
else:
value.setLen(0)
return false

View File

@ -11,7 +11,6 @@ import strutils, tables
import chronos, chronicles, stew/byteutils
import stream/connection,
vbuffer,
errors,
protocols/protocol
logScope:

View File

@ -14,8 +14,6 @@ import types,
nimcrypto/utils,
../../stream/connection,
../../stream/bufferstream,
../../utility,
../../errors,
../../peerinfo
export connection
@ -189,14 +187,11 @@ proc closeRemote*(s: LPChannel) {.async.} =
# stack = getStackTrace()
trace "got EOF, closing channel"
# wait for all data in the buffer to be consumed
while s.len > 0:
await s.dataReadEvent.wait()
s.dataReadEvent.clear()
await s.drainBuffer()
s.isEof = true # set EOF immediately to prevent further reads
await s.close() # close local end
# call to avoid leaks
await procCall BufferStream(s).close() # close parent bufferstream
trace "channel closed on EOF"
@ -227,7 +222,11 @@ method reset*(s: LPChannel) {.base, async, gcsafe.} =
# might be dead already - reset is always
# optimistic
asyncCheck s.resetMessage()
# drain the buffer before closing
await s.drainBuffer()
await procCall BufferStream(s).close()
s.isEof = true
s.closedLocal = true
@ -254,11 +253,11 @@ method close*(s: LPChannel) {.async, gcsafe.} =
if s.atEof: # already closed by remote close parent buffer immediately
await procCall BufferStream(s).close()
except CancelledError as exc:
await s.reset() # reset on timeout
await s.reset()
raise exc
except CatchableError as exc:
trace "exception closing channel"
await s.reset() # reset on timeout
await s.reset()
trace "lpchannel closed local"

View File

@ -13,7 +13,6 @@ import ../muxer,
../../stream/connection,
../../stream/bufferstream,
../../utility,
../../errors,
../../peerinfo,
coder,
types,

View File

@ -10,7 +10,6 @@
import chronos, chronicles
import ../protocols/protocol,
../stream/connection,
../peerinfo,
../errors
logScope:

View File

@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} =
## Write PeerID value ``peerid`` to buffer ``vb``.
vb.writeSeq(pid.data)
proc initProtoField*(index: int, pid: PeerID): ProtoField =
proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} =
## Initialize ProtoField with PeerID ``value``.
result = initProtoField(index, pid.data)
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {.
deprecated.} =
## Read ``PeerID`` from ProtoBuf's message and validate it.
var pid: PeerID
result = getLengthValue(data, field, pid.data)
@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int =
result = -1
else:
value = pid
proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) =
## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding.
write(pb, field, pid.data)
proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool =
## Read ``PeerID`` from ProtoBuf's message and validate it
var buffer: seq[byte]
var peerId: PeerID
if not(getField(pb, field, buffer)):
return false
if len(buffer) == 0:
return false
if peerId.init(buffer):
pid = peerId
true
else:
false

View File

@ -11,7 +11,7 @@
{.push raises: [Defect].}
import ../varint
import ../varint, stew/endians2
const
MaxMessageSize* = 1'u shl 22
@ -32,10 +32,14 @@ type
offset*: int
length*: int
ProtoHeader* = object
wire*: ProtoFieldKind
index*: uint64
ProtoField* = object
## Protobuf's message field representation object
index: int
case kind: ProtoFieldKind
index*: int
case kind*: ProtoFieldKind
of Varint:
vint*: uint64
of Fixed64:
@ -47,13 +51,35 @@ type
of StartGroup, EndGroup:
discard
template protoHeader*(index: int, wire: ProtoFieldKind): uint =
## Get protobuf's field header integer for ``index`` and ``wire``.
((uint(index) shl 3) or cast[uint](wire))
ProtoResult {.pure.} = enum
VarintDecodeError,
MessageIncompleteError,
BufferOverflowError,
MessageSizeTooBigError,
NoError
template protoHeader*(field: ProtoField): uint =
ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 |
hint | hint32 | hint64 | float32 | float64
const
SupportedWireTypes* = {
int(ProtoFieldKind.Varint),
int(ProtoFieldKind.Fixed64),
int(ProtoFieldKind.Length),
int(ProtoFieldKind.Fixed32)
}
template checkFieldNumber*(i: int) =
doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999),
"Incorrect or reserved field number")
template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 =
## Get protobuf's field header integer for ``index`` and ``wire``.
((uint64(index) shl 3) or uint64(wire))
template getProtoHeader*(field: ProtoField): uint64 =
## Get protobuf's field header integer for ``field``.
((uint(field.index) shl 3) or cast[uint](field.kind))
((uint64(field.index) shl 3) or uint64(field.kind))
template toOpenArray*(pb: ProtoBuffer): untyped =
toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1)
@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int =
proc vsizeof*(field: ProtoField): int {.inline.} =
## Returns number of bytes required to store protobuf's field ``field``.
result = vsizeof(protoHeader(field))
case field.kind
of ProtoFieldKind.Varint:
result += vsizeof(field.vint)
vsizeof(getProtoHeader(field)) + vsizeof(field.vint)
of ProtoFieldKind.Fixed64:
result += sizeof(field.vfloat64)
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64)
of ProtoFieldKind.Fixed32:
result += sizeof(field.vfloat32)
vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32)
of ProtoFieldKind.Length:
result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer)
vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) +
len(field.vbuffer)
else:
discard
0
proc initProtoField*(index: int, value: SomeVarint): ProtoField =
proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index)
when type(value) is uint64:
@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField =
else:
result.vint = cast[uint64](value)
proc initProtoField*(index: int, value: bool): ProtoField =
proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} =
## Initialize ProtoField with integer value.
result = ProtoField(kind: Varint, index: index)
result.vint = byte(value)
proc initProtoField*(index: int, value: openarray[byte]): ProtoField =
proc initProtoField*(index: int,
value: openarray[byte]): ProtoField {.deprecated.} =
## Initialize ProtoField with bytes array.
result = ProtoField(kind: Length, index: index)
if len(value) > 0:
result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: string): ProtoField =
proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} =
## Initialize ProtoField with string.
result = ProtoField(kind: Length, index: index)
if len(value) > 0:
result.vbuffer = newSeq[byte](len(value))
copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value))
proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} =
proc initProtoField*(index: int,
value: ProtoBuffer): ProtoField {.deprecated, inline.} =
## Initialize ProtoField with nested message stored in ``value``.
##
## Note: This procedure performs shallow copy of ``value`` sequence.
@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0,
result.offset = offset
result.options = options
proc initProtoBuffer*(data: openarray[byte], offset = 0,
options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with copy of ``data``.
result.buffer = @data
result.offset = offset
result.options = options
proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
## Initialize ProtoBuffer with new sequence of capacity ``cap``.
result.buffer = newSeqOfCap[byte](128)
@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer =
result.offset = 10
elif {WithUint32LeLength, WithUint32BeLength} * options != {}:
# Our buffer will start from position 4, so we can store length of buffer
# in [0, 9].
# in [0, 3].
result.buffer.setLen(4)
result.offset = 4
proc write*(pb: var ProtoBuffer, field: ProtoField) =
proc write*[T: ProtoScalar](pb: var ProtoBuffer,
field: int, value: T) =
checkFieldNumber(field)
var length = 0
when (T is uint64) or (T is uint32) or (T is uint) or
(T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) +
vsizeof(value)
let header = ProtoFieldKind.Varint
elif T is float32:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) +
sizeof(T)
let header = ProtoFieldKind.Fixed32
elif T is float64:
let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) +
sizeof(T)
let header = ProtoFieldKind.Fixed64
pb.buffer.setLen(len(pb.buffer) + flength)
let hres = PB.putUVarint(pb.toOpenArray(), length,
getProtoHeader(field, header))
doAssert(hres.isOk())
pb.offset += length
when (T is uint64) or (T is uint32) or (T is uint):
let vres = PB.putUVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif (T is zint64) or (T is zint32) or (T is zint) or
(T is hint64) or (T is hint32) or (T is hint):
let vres = putSVarint(pb.toOpenArray(), length, value)
doAssert(vres.isOk())
pb.offset += length
elif T is float32:
doAssert(pb.isEnough(sizeof(T)))
let u32 = cast[uint32](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
pb.offset += sizeof(T)
elif T is float64:
doAssert(pb.isEnough(sizeof(T)))
let u64 = cast[uint64](value)
pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
pb.offset += sizeof(T)
proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int,
                                  value: openarray[T]) =
  ## Encode ``value`` as a packed repeated field: a single length-delimited
  ## field ``field`` whose payload is the concatenation of the encoded
  ## scalar elements.
  checkFieldNumber(field)
  var length = 0
  # Size of the packed payload: sum of per-element varint sizes for integer
  ## types, or a fixed element size for floats.
  let dlength =
    when (T is uint64) or (T is uint32) or (T is uint) or
         (T is zint64) or (T is zint32) or (T is zint) or
         (T is hint64) or (T is hint32) or (T is hint):
      var res = 0
      for item in value:
        res += vsizeof(item)
      res
    elif (T is float32) or (T is float64):
      len(value) * sizeof(T)
  let header = getProtoHeader(field, ProtoFieldKind.Length)
  let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength
  # Grow the buffer once, then write header, payload length and elements.
  pb.buffer.setLen(len(pb.buffer) + flength)
  let hres = PB.putUVarint(pb.toOpenArray(), length, header)
  doAssert(hres.isOk())
  pb.offset += length
  length = 0
  let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength))
  doAssert(lres.isOk())
  pb.offset += length
  for item in value:
    when (T is uint64) or (T is uint32) or (T is uint):
      length = 0
      let vres = PB.putUVarint(pb.toOpenArray(), length, item)
      doAssert(vres.isOk())
      pb.offset += length
    elif (T is zint64) or (T is zint32) or (T is zint) or
         (T is hint64) or (T is hint32) or (T is hint):
      length = 0
      # Signed/zigzag values go through the signed varint encoder.
      let vres = PB.putSVarint(pb.toOpenArray(), length, item)
      doAssert(vres.isOk())
      pb.offset += length
    elif T is float32:
      doAssert(pb.isEnough(sizeof(T)))
      let u32 = cast[uint32](item)
      # Floats are stored little-endian per protobuf fixed32/fixed64 rules.
      pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE()
      pb.offset += sizeof(T)
    elif T is float64:
      doAssert(pb.isEnough(sizeof(T)))
      let u64 = cast[uint64](item)
      pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE()
      pb.offset += sizeof(T)
proc write*[T: byte|char](pb: var ProtoBuffer, field: int,
                          value: openarray[T]) =
  ## Encode ``value`` as a length-delimited field ``field``: field header,
  ## varint payload length, then the raw bytes of ``value``.
  checkFieldNumber(field)
  var length = 0
  let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) +
                vsizeof(uint64(len(value))) + len(value)
  # Grow the buffer once, then write header, length and payload in order.
  pb.buffer.setLen(len(pb.buffer) + flength)
  let hres = PB.putUVarint(pb.toOpenArray(), length,
                           getProtoHeader(field, ProtoFieldKind.Length))
  doAssert(hres.isOk())
  pb.offset += length
  let lres = PB.putUVarint(pb.toOpenArray(), length,
                           uint64(len(value)))
  doAssert(lres.isOk())
  pb.offset += length
  if len(value) > 0:
    doAssert(pb.isEnough(len(value)))
    copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value))
    pb.offset += len(value)
proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} =
  ## Store the already-encoded sub-message ``value`` into buffer ``pb`` as a
  ## length-delimited field with number ``field``.
  pb.write(field, value.buffer)
proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} =
## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``.
var length = 0
var res: VarintResult[void]
pb.buffer.setLen(len(pb.buffer) + vsizeof(field))
res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field))
res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field))
doAssert(res.isOk())
pb.offset += length
case field.kind
@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) =
pb.offset = pos
elif WithUint32BeLength in pb.options:
let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte((size shr 24) and 0xFF'u)
pb.buffer[1] = byte((size shr 16) and 0xFF'u)
pb.buffer[2] = byte((size shr 8) and 0xFF'u)
pb.buffer[3] = byte(size and 0xFF'u)
pb.buffer[0 ..< 4] = toBytesBE(uint32(size))
pb.offset = 4
elif WithUint32LeLength in pb.options:
let size = uint(len(pb.buffer) - 4)
pb.buffer[0] = byte(size and 0xFF'u)
pb.buffer[1] = byte((size shr 8) and 0xFF'u)
pb.buffer[2] = byte((size shr 16) and 0xFF'u)
pb.buffer[3] = byte((size shr 24) and 0xFF'u)
pb.buffer[0 ..< 4] = toBytesLE(uint32(size))
pb.offset = 4
else:
pb.offset = 0
proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool =
  ## Decode the next field header at ``data.offset`` into ``header`` and
  ## advance the offset. Returns ``false`` on a malformed varint or an
  ## unsupported wire type, leaving ``data.offset`` untouched.
  var consumed = 0
  var raw = 0'u64
  if not PB.getUVarint(data.toOpenArray(), consumed, raw).isOk():
    return false
  let wire = raw and 0x07
  if wire notin SupportedWireTypes:
    return false
  data.offset += consumed
  header = ProtoHeader(index: uint64(raw shr 3),
                       wire: cast[ProtoFieldKind](wire))
  true
proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool =
  ## Advance ``data.offset`` past one value of wire type ``header.wire``.
  ## Returns ``false`` when the buffer does not contain a complete value or
  ## the wire type cannot be skipped (group markers).
  case header.wire
  of ProtoFieldKind.Varint:
    var consumed = 0
    var dummy = 0'u64
    if not PB.getUVarint(data.toOpenArray(), consumed, dummy).isOk():
      return false
    data.offset += consumed
    true
  of ProtoFieldKind.Fixed32:
    if not data.isEnough(sizeof(uint32)):
      return false
    data.offset += sizeof(uint32)
    true
  of ProtoFieldKind.Fixed64:
    if not data.isEnough(sizeof(uint64)):
      return false
    data.offset += sizeof(uint64)
    true
  of ProtoFieldKind.Length:
    var consumed = 0
    var size = 0'u64
    if not PB.getUVarint(data.toOpenArray(), consumed, size).isOk():
      return false
    # The length prefix is consumed even when the size checks below fail,
    # matching the original behaviour.
    data.offset += consumed
    if size > uint64(MaxMessageSize) or not data.isEnough(int(size)):
      return false
    data.offset += int(size)
    true
  of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup:
    # Deprecated group wire types are not supported.
    false
proc getValue[T: ProtoScalar](data: var ProtoBuffer,
                              header: ProtoHeader,
                              outval: var T): ProtoResult =
  ## Decode a single scalar of type ``T`` at ``data.offset`` into ``outval``
  ## and advance the offset past it. The caller must have matched
  ## ``header.wire`` to ``T`` beforehand (enforced here with ``doAssert``).
  when (T is uint64) or (T is uint32) or (T is uint):
    doAssert(header.wire == ProtoFieldKind.Varint)
    var length = 0
    var value = T(0)
    if PB.getUVarint(data.toOpenArray(), length, value).isOk():
      data.offset += length
      outval = value
      ProtoResult.NoError
    else:
      ProtoResult.VarintDecodeError
  elif (T is zint64) or (T is zint32) or (T is zint) or
       (T is hint64) or (T is hint32) or (T is hint):
    doAssert(header.wire == ProtoFieldKind.Varint)
    var length = 0
    var value = T(0)
    # Signed/zigzag types decode through the signed varint reader.
    if getSVarint(data.toOpenArray(), length, value).isOk():
      data.offset += length
      outval = value
      ProtoResult.NoError
    else:
      ProtoResult.VarintDecodeError
  elif T is float32:
    doAssert(header.wire == ProtoFieldKind.Fixed32)
    if data.isEnough(sizeof(float32)):
      # fixed32 payloads are little-endian on the wire.
      outval = cast[float32](fromBytesLE(uint32, data.toOpenArray()))
      data.offset += sizeof(float32)
      ProtoResult.NoError
    else:
      ProtoResult.MessageIncompleteError
  elif T is float64:
    doAssert(header.wire == ProtoFieldKind.Fixed64)
    if data.isEnough(sizeof(float64)):
      # fixed64 payloads are little-endian on the wire.
      outval = cast[float64](fromBytesLE(uint64, data.toOpenArray()))
      data.offset += sizeof(float64)
      ProtoResult.NoError
    else:
      ProtoResult.MessageIncompleteError
proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader,
                           outBytes: var openarray[T],
                           outLength: var int): ProtoResult =
  ## Decode a length-delimited value into caller-supplied storage
  ## ``outBytes``; ``outLength`` receives the value's actual size.
  ## When ``outBytes`` is too small the value is skipped over and
  ## ``BufferOverflowError`` returned with ``outLength`` still set, so the
  ## caller can retry with a properly sized buffer.
  doAssert(header.wire == ProtoFieldKind.Length)
  var length = 0
  var bsize = 0'u64
  outLength = 0
  if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
    data.offset += length
    if bsize <= uint64(MaxMessageSize):
      if data.isEnough(int(bsize)):
        outLength = int(bsize)
        if len(outBytes) >= int(bsize):
          if bsize > 0'u64:
            copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
          data.offset += int(bsize)
          ProtoResult.NoError
        else:
          # Buffer overflow should not be critical failure
          data.offset += int(bsize)
          ProtoResult.BufferOverflowError
      else:
        ProtoResult.MessageIncompleteError
    else:
      # Refuse payloads larger than ``MaxMessageSize``.
      ProtoResult.MessageSizeTooBigError
  else:
    ProtoResult.VarintDecodeError
proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader,
                                  outBytes: var T): ProtoResult =
  ## Decode a length-delimited value into ``outBytes``, resizing it to fit.
  ## ``data.offset`` is advanced past the value on success; ``outBytes`` is
  ## emptied on any error.
  doAssert(header.wire == ProtoFieldKind.Length)
  var length = 0
  var bsize = 0'u64
  outBytes.setLen(0)
  if PB.getUVarint(data.toOpenArray(), length, bsize).isOk():
    data.offset += length
    if bsize <= uint64(MaxMessageSize):
      if data.isEnough(int(bsize)):
        outBytes.setLen(bsize)
        if bsize > 0'u64:
          copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize))
        data.offset += int(bsize)
        ProtoResult.NoError
      else:
        ProtoResult.MessageIncompleteError
    else:
      # Refuse payloads larger than ``MaxMessageSize``.
      ProtoResult.MessageSizeTooBigError
  else:
    ProtoResult.VarintDecodeError
proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int,
                               output: var T): bool =
  ## Scan the whole message for scalar field ``field`` and store its value in
  ## ``output``. Returns ``true`` when at least one occurrence was decoded
  ## (with repeated occurrences the last one wins); returns ``false`` on any
  ## decoding error.
  checkFieldNumber(field)
  var value: T
  var res = false
  # Work on a copy so the caller's buffer offset is left untouched.
  var pb = data
  output = T(0)
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      output = T(0)
      return false
    # Wire type that a value of type ``T`` must be encoded with.
    let wireCheck =
      when (T is uint64) or (T is uint32) or (T is uint) or
           (T is zint64) or (T is zint32) or (T is zint) or
           (T is hint64) or (T is hint32) or (T is hint):
        header.wire == ProtoFieldKind.Varint
      elif T is float32:
        header.wire == ProtoFieldKind.Fixed32
      elif T is float64:
        header.wire == ProtoFieldKind.Fixed64
    if header.index == uint64(field):
      if wireCheck:
        let r = getValue(pb, header, value)
        case r
        of ProtoResult.NoError:
          res = true
          output = value
        else:
          return false
      else:
        # We are ignoring wire types different from what we expect, because it
        # is how `protoc` is working.
        if not(skipValue(pb, header)):
          output = T(0)
          return false
    else:
      # Not the field we want — skip its value and keep scanning.
      if not(skipValue(pb, header)):
        output = T(0)
        return false
  res
proc getField*[T: byte|char](data: ProtoBuffer, field: int,
                             output: var openarray[T],
                             outlen: var int): bool =
  ## Scan the message for length-delimited field ``field`` and copy its
  ## payload into caller-supplied ``output``; ``outlen`` receives the
  ## payload size. A payload larger than ``output`` is not fatal —
  ## ``outlen`` is still set so the caller can retry with a bigger buffer.
  ## On hard decoding errors ``output`` is zeroed and ``false`` returned.
  checkFieldNumber(field)
  # Work on a copy so the caller's buffer offset is left untouched.
  var pb = data
  var res = false
  outlen = 0
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      if len(output) > 0:
        zeroMem(addr output[0], len(output))
      outlen = 0
      return false
    if header.index == uint64(field):
      if header.wire == ProtoFieldKind.Length:
        let r = getValue(pb, header, output, outlen)
        case r
        of ProtoResult.NoError:
          res = true
        of ProtoResult.BufferOverflowError:
          # Buffer overflow error is not critical error, we still can get
          # field values with proper size.
          discard
        else:
          if len(output) > 0:
            zeroMem(addr output[0], len(output))
          return false
      else:
        # We are ignoring wire types different from ProtoFieldKind.Length,
        # because it is how `protoc` is working.
        if not(skipValue(pb, header)):
          if len(output) > 0:
            zeroMem(addr output[0], len(output))
          outlen = 0
          return false
    else:
      # Not the field we want — skip its value and keep scanning.
      if not(skipValue(pb, header)):
        if len(output) > 0:
          zeroMem(addr output[0], len(output))
        outlen = 0
        return false
  res
proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
                                    output: var T): bool =
  ## Scan the message for length-delimited field ``field`` and place its
  ## payload in ``output`` (resized to fit). Returns ``true`` when at least
  ## one occurrence was decoded (the last occurrence wins); on decoding
  ## errors ``output`` is emptied and ``false`` returned.
  checkFieldNumber(field)
  var res = false
  # Work on a copy so the caller's buffer offset is left untouched.
  var pb = data
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      output.setLen(0)
      return false
    if header.index == uint64(field):
      if header.wire == ProtoFieldKind.Length:
        let r = getValue(pb, header, output)
        case r
        of ProtoResult.NoError:
          res = true
        of ProtoResult.BufferOverflowError:
          # Buffer overflow error is not critical error, we still can get
          # field values with proper size.
          discard
        else:
          output.setLen(0)
          return false
      else:
        # We are ignoring wire types different from ProtoFieldKind.Length,
        # because it is how `protoc` is working.
        if not(skipValue(pb, header)):
          output.setLen(0)
          return false
    else:
      # Not the field we want — skip its value and keep scanning.
      if not(skipValue(pb, header)):
        output.setLen(0)
        return false
  res
proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {.
     inline.} =
  ## Extract length-delimited field ``field`` from ``pb`` and wrap its
  ## payload in a fresh ``ProtoBuffer`` for sub-message decoding.
  ## ``output`` is left untouched when the field is absent or malformed.
  var payload: seq[byte]
  result = pb.getField(field, payload)
  if result:
    output = initProtoBuffer(payload)
proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int,
                                            output: var seq[T]): bool =
  ## Collect every occurrence of length-delimited field ``field`` into
  ## ``output``. Returns ``true`` only when at least one value was found;
  ## any decoding error clears ``output`` and returns ``false``.
  checkFieldNumber(field)
  # Work on a copy so the caller's buffer offset is left untouched.
  var pb = data
  output.setLen(0)
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      output.setLen(0)
      return false
    let wanted = (header.index == uint64(field)) and
                 (header.wire == ProtoFieldKind.Length)
    if wanted:
      var item: T
      if getValue(pb, header, item) == ProtoResult.NoError:
        output.add(item)
      else:
        output.setLen(0)
        return false
    else:
      # Unknown fields and unexpected wire types are skipped, mirroring
      # `protoc` behaviour.
      if not(skipValue(pb, header)):
        output.setLen(0)
        return false
  len(output) > 0
proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer,
                                                  field: int,
                                                  output: var seq[T]): bool =
  ## Collect every occurrence of scalar field ``field`` into ``output``.
  ## Returns ``true`` only when at least one value was found; any decoding
  ## error clears ``output`` and returns ``false``.
  checkFieldNumber(field)
  var pb = data
  output.setLen(0)
  # Wire type that a value of type ``T`` must be encoded with. Checking for
  # the exact type — instead of accepting any scalar wire type and letting
  # ``getValue``'s ``doAssert`` fire — prevents a crafted message (e.g. a
  # Fixed32 value where a Varint is expected) from crashing the decoder.
  # Mismatched wire types are skipped, as in ``getField``.
  let expectedWire =
    when T is uint64:
      ProtoFieldKind.Varint
    elif T is float32:
      ProtoFieldKind.Fixed32
    elif T is float64:
      ProtoFieldKind.Fixed64
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      output.setLen(0)
      return false
    if (header.index == uint64(field)) and (header.wire == expectedWire):
      var item: T
      let r = getValue(pb, header, item)
      case r
      of ProtoResult.NoError:
        output.add(item)
      else:
        output.setLen(0)
        return false
    else:
      # Unknown fields and unexpected wire types are skipped, mirroring
      # `protoc` behaviour.
      if not(skipValue(pb, header)):
        output.setLen(0)
        return false
  if len(output) > 0:
    true
  else:
    false
proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int,
                                             output: var seq[T]): bool =
  ## Decode a packed repeated scalar field: a single length-delimited value
  ## whose payload is a concatenation of encoded scalars of type ``T``.
  ## Returns ``true`` only when at least one element was decoded; any error
  ## clears ``output`` and returns ``false``.
  checkFieldNumber(field)
  # Work on a copy so the caller's buffer offset is left untouched.
  var pb = data
  output.setLen(0)
  while not(pb.isEmpty()):
    var header: ProtoHeader
    if not(pb.getHeader(header)):
      output.setLen(0)
      return false
    if header.index == uint64(field):
      if header.wire == ProtoFieldKind.Length:
        var arritem: seq[byte]
        let rarr = getValue(pb, header, arritem)
        case rarr
        of ProtoResult.NoError:
          # Re-parse the payload as a bare sequence of scalars using a
          # synthetic header carrying the wire type implied by ``T``.
          var pbarr = initProtoBuffer(arritem)
          let itemHeader =
            when (T is uint64) or (T is uint32) or (T is uint) or
                 (T is zint64) or (T is zint32) or (T is zint) or
                 (T is hint64) or (T is hint32) or (T is hint):
              ProtoHeader(wire: ProtoFieldKind.Varint)
            elif T is float32:
              ProtoHeader(wire: ProtoFieldKind.Fixed32)
            elif T is float64:
              ProtoHeader(wire: ProtoFieldKind.Fixed64)
          while not(pbarr.isEmpty()):
            var item: T
            let res = getValue(pbarr, itemHeader, item)
            case res
            of ProtoResult.NoError:
              output.add(item)
            else:
              output.setLen(0)
              return false
        else:
          output.setLen(0)
          return false
      else:
        # Wire types other than Length are skipped, mirroring `protoc`.
        if not(skipValue(pb, header)):
          output.setLen(0)
          return false
    else:
      # Not the field we want — skip its value and keep scanning.
      if not(skipValue(pb, header)):
        output.setLen(0)
        return false
  if len(output) > 0:
    true
  else:
    false
proc getVarintValue*(data: var ProtoBuffer, field: int,
value: var SomeVarint): int =
value: var SomeVarint): int {.deprecated.} =
## Get value of `Varint` type.
var length = 0
var header = 0'u64
var soffset = data.offset
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length
if header == protoHeader(field, Varint):
if header == getProtoHeader(field, Varint):
if not data.isEmpty():
when type(value) is int32 or type(value) is int64 or type(value) is int:
let res = getSVarint(data.toOpenArray(), length, value)
@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int,
data.offset = soffset
proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
buffer: var T): int =
buffer: var T): int {.deprecated.} =
## Get value of `Length` type.
var length = 0
var header = 0'u64
@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
var soffset = data.offset
result = -1
buffer.setLen(0)
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk():
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, header).isOk():
data.offset += length
if header == protoHeader(field, Length):
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk():
if header == getProtoHeader(field, Length):
if not data.isEmpty() and PB.getUVarint(data.toOpenArray(),
length, ssize).isOk():
data.offset += length
if ssize <= MaxMessageSize and data.isEnough(int(ssize)):
buffer.setLen(ssize)
@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int,
data.offset = soffset
proc getBytes*(data: var ProtoBuffer, field: int,
buffer: var seq[byte]): int {.inline.} =
buffer: var seq[byte]): int {.deprecated, inline.} =
## Get value of `Length` type as bytes.
result = getLengthValue(data, field, buffer)
proc getString*(data: var ProtoBuffer, field: int,
buffer: var string): int {.inline.} =
buffer: var string): int {.deprecated, inline.} =
## Get value of `Length` type as string.
result = getLengthValue(data, field, buffer)
proc enterSubmessage*(pb: var ProtoBuffer): int =
proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} =
## Processes protobuf's sub-message and adjust internal offset to enter
## inside of sub-message. Returns field index of sub-message field or
## ``0`` on error.
@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
var msize = 0'u64
var soffset = pb.offset
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk():
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, header).isOk():
pb.offset += length
if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length):
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk():
if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(),
length, msize).isOk():
pb.offset += length
if msize <= MaxMessageSize and pb.isEnough(int(msize)):
pb.length = int(msize)
@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int =
# Restore offset on error
pb.offset = soffset
proc skipSubmessage*(pb: var ProtoBuffer) =
proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} =
## Skip current protobuf's sub-message and adjust internal offset to the
## end of sub-message.
doAssert(pb.length != 0)

View File

@ -47,61 +47,49 @@ type
proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer =
result = initProtoBuffer()
result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet()))
result.write(1, peerInfo.publicKey.get().getBytes().tryGet())
for ma in peerInfo.addrs:
result.write(initProtoField(2, ma.data.buffer))
result.write(2, ma.data.buffer)
for proto in peerInfo.protocols:
result.write(initProtoField(3, proto))
result.write(3, proto)
result.write(initProtoField(4, observedAddr.data.buffer))
result.write(4, observedAddr.data.buffer)
let protoVersion = ProtoVersion
result.write(initProtoField(5, protoVersion))
result.write(5, protoVersion)
let agentVersion = AgentVersion
result.write(initProtoField(6, agentVersion))
result.write(6, agentVersion)
result.finish()
proc decodeMsg*(buf: seq[byte]): IdentifyInfo =
var pb = initProtoBuffer(buf)
result.pubKey = none(PublicKey)
var pubKey: PublicKey
if pb.getValue(1, pubKey) > 0:
if pb.getField(1, pubKey):
trace "read public key from message", pubKey = ($pubKey).shortLog
result.pubKey = some(pubKey)
result.addrs = newSeq[MultiAddress]()
var address = newSeq[byte]()
while pb.getBytes(2, address) > 0:
if len(address) != 0:
var copyaddr = address
var ma = MultiAddress.init(copyaddr).tryGet()
result.addrs.add(ma)
trace "read address bytes from message", address = ma
address.setLen(0)
if pb.getRepeatedField(2, result.addrs):
trace "read addresses from message", addresses = result.addrs
var proto = ""
while pb.getString(3, proto) > 0:
trace "read proto from message", proto = proto
result.protos.add(proto)
proto = ""
if pb.getRepeatedField(3, result.protos):
trace "read protos from message", protocols = result.protos
var observableAddr = newSeq[byte]()
if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr
var ma = MultiAddress.init(observableAddr).tryGet()
trace "read observedAddr from message", address = ma
result.observedAddr = some(ma)
var observableAddr: MultiAddress
if pb.getField(4, observableAddr):
trace "read observableAddr from message", address = observableAddr
result.observedAddr = some(observableAddr)
var protoVersion = ""
if pb.getString(5, protoVersion) > 0:
if pb.getField(5, protoVersion):
trace "read protoVersion from message", protoVersion = protoVersion
result.protoVersion = some(protoVersion)
var agentVersion = ""
if pb.getString(6, agentVersion) > 0:
if pb.getField(6, agentVersion):
trace "read agentVersion from message", agentVersion = agentVersion
result.agentVersion = some(agentVersion)

View File

@ -15,9 +15,7 @@ import pubsub,
rpc/[messages, message],
../../stream/connection,
../../peerid,
../../peerinfo,
../../utility,
../../errors
../../peerinfo
logScope:
topics = "floodsub"
@ -26,7 +24,7 @@ const FloodSubCodec* = "/floodsub/1.0.0"
type
FloodSub* = ref object of PubSub
floodsub*: Table[string, HashSet[string]] # topic to remote peer map
floodsub*: PeerTable # topic to remote peer map
seen*: TimedCache[string] # list of messages forwarded to peers
method subscribeTopic*(f: FloodSub,
@ -35,23 +33,28 @@ method subscribeTopic*(f: FloodSub,
peerId: string) {.gcsafe, async.} =
await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId)
let peer = f.peers.getOrDefault(peerId)
if peer == nil:
debug "subscribeTopic on a nil peer!"
return
if topic notin f.floodsub:
f.floodsub[topic] = initHashSet[string]()
f.floodsub[topic] = initHashSet[PubSubPeer]()
if subscribe:
trace "adding subscription for topic", peer = peerId, name = topic
trace "adding subscription for topic", peer = peer.id, name = topic
# subscribe the peer to the topic
f.floodsub[topic].incl(peerId)
f.floodsub[topic].incl(peer)
else:
trace "removing subscription for topic", peer = peerId, name = topic
trace "removing subscription for topic", peer = peer.id, name = topic
# unsubscribe the peer from the topic
f.floodsub[topic].excl(peerId)
f.floodsub[topic].excl(peer)
method handleDisconnect*(f: FloodSub, peer: PubSubPeer) =
## handle peer disconnects
for t in toSeq(f.floodsub.keys):
if t in f.floodsub:
f.floodsub[t].excl(peer.id)
f.floodsub[t].excl(peer)
procCall PubSub(f).handleDisconnect(peer)
@ -62,7 +65,7 @@ method rpcHandler*(f: FloodSub,
for m in rpcMsgs: # for all RPC messages
if m.messages.len > 0: # if there are any messages
var toSendPeers: HashSet[string] = initHashSet[string]()
var toSendPeers = initHashSet[PubSubPeer]()
for msg in m.messages: # for every message
let msgId = f.msgIdProvider(msg)
logScope: msgId
@ -138,7 +141,8 @@ method publish*(f: FloodSub,
let (published, failed) = await f.sendHelper(f.floodsub.getOrDefault(topic), @[msg])
for p in failed:
let peer = f.peers.getOrDefault(p)
f.handleDisconnect(peer) # cleanup failed peers
if not isNil(peer):
f.handleDisconnect(peer) # cleanup failed peers
libp2p_pubsub_messages_published.inc(labelValues = [topic])
@ -157,6 +161,6 @@ method initPubSub*(f: FloodSub) =
procCall PubSub(f).initPubSub()
f.peers = initTable[string, PubSubPeer]()
f.topics = initTable[string, Topic]()
f.floodsub = initTable[string, HashSet[string]]()
f.floodsub = initTable[string, HashSet[PubSubPeer]]()
f.seen = newTimedCache[string](2.minutes)
f.init()

View File

@ -62,11 +62,11 @@ type
GossipSub* = ref object of FloodSub
parameters*: GossipSubParams
mesh*: Table[string, HashSet[string]] # meshes - topic to peer
fanout*: Table[string, HashSet[string]] # fanout - topic to peer
gossipsub*: Table[string, HashSet[string]] # topic to peer map of all gossipsub peers
explicit*: Table[string, HashSet[string]] # # topic to peer map of all explicit peers
explicitPeers*: HashSet[string] # explicit (always connected/forward) peers
mesh*: PeerTable # peers that we send messages to when we are subscribed to the topic
fanout*: PeerTable # peers that we send messages to when we're not subscribed to the topic
gossipsub*: PeerTable # peers that are subscribed to a topic
explicit*: PeerTable # directpeers that we keep alive explicitly
explicitPeers*: HashSet[string] # explicit (always connected/forward) peers
lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics
gossip*: Table[string, seq[ControlIHave]] # pending gossip
control*: Table[string, ControlMessage] # pending control messages
@ -97,6 +97,28 @@ proc init*(_: type[GossipSubParams]): GossipSubParams =
publishThreshold: 1.0,
)
func addPeer(table: var PeerTable, topic: string, peer: PubSubPeer): bool =
# returns true if the peer was added, false if it was already in the collection
not table.mgetOrPut(topic, initHashSet[PubSubPeer]()).containsOrIncl(peer)
func removePeer(table: var PeerTable, topic: string, peer: PubSubPeer) =
table.withValue(topic, peers):
peers[].excl(peer)
if peers[].len == 0:
table.del(topic)
func hasPeer(table: PeerTable, topic: string, peer: PubSubPeer): bool =
(topic in table) and (peer in table[topic])
func peers(table: PeerTable, topic: string): int =
if topic in table:
table[topic].len
else:
0
func getPeers(table: Table[string, HashSet[string]], topic: string): HashSet[string] =
table.getOrDefault(topic, initHashSet[string]())
method init*(g: GossipSub) =
proc handler(conn: Connection, proto: string) {.async.} =
## main protocol handler that gets triggered on every
@ -116,141 +138,104 @@ method init*(g: GossipSub) =
proc replenishFanout(g: GossipSub, topic: string) =
## get fanout peers for a topic
trace "about to replenish fanout"
if topic notin g.fanout:
g.fanout[topic] = initHashSet[string]()
if g.fanout.getOrDefault(topic).len < GossipSubDLo:
trace "replenishing fanout", peers = g.fanout.getOrDefault(topic).len
if topic in toSeq(g.gossipsub.keys):
for p in g.gossipsub.getOrDefault(topic):
if not g.fanout[topic].containsOrIncl(p):
if g.fanout.getOrDefault(topic).len == GossipSubD:
if g.fanout.peers(topic) < GossipSubDLo:
trace "replenishing fanout", peers = g.fanout.peers(topic)
if topic in g.gossipsub:
for peer in g.gossipsub[topic]:
if g.fanout.addPeer(topic, peer):
if g.fanout.peers(topic) == GossipSubD:
break
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.getOrDefault(topic).len.int64,
labelValues = [topic])
.set(g.fanout.peers(topic).int64, labelValues = [topic])
trace "fanout replenished with peers", peers = g.fanout.getOrDefault(topic).len
template moveToMeshHelper(g: GossipSub,
topic: string,
table: Table[string, HashSet[string]]) =
## move peers from `table` into `mesh`
##
var peerIds = toSeq(table.getOrDefault(topic))
logScope:
topic = topic
meshPeers = g.mesh.getOrDefault(topic).len
peers = peerIds.len
shuffle(peerIds)
for id in peerIds:
if g.mesh.getOrDefault(topic).len > GossipSubD:
break
trace "gathering peers for mesh"
if topic notin table:
continue
trace "getting peers", topic,
peers = peerIds.len
table[topic].excl(id) # always exclude
if id in g.mesh[topic]:
continue # we already have this peer in the mesh, try again
if id in g.peers:
let p = g.peers[id]
if p.connected:
# send a graft message to the peer
await p.sendGraft(@[topic])
g.mesh[topic].incl(id)
trace "got peer", peer = id
trace "fanout replenished with peers", peers = g.fanout.peers(topic)
proc rebalanceMesh(g: GossipSub, topic: string) {.async.} =
try:
trace "about to rebalance mesh"
# create a mesh topic that we're subscribing to
if topic notin g.mesh:
g.mesh[topic] = initHashSet[string]()
trace "about to rebalance mesh"
# create a mesh topic that we're subscribing to
if g.mesh.getOrDefault(topic).len < GossipSubDlo:
trace "replenishing mesh", topic
# replenish the mesh if we're below GossipSubDlo
var
grafts, prunes: seq[PubSubPeer]
# move fanout nodes first
g.moveToMeshHelper(topic, g.fanout)
if g.mesh.peers(topic) < GossipSubDlo:
trace "replenishing mesh", topic, peers = g.mesh.peers(topic)
# replenish the mesh if we're below GossipSubDlo
var newPeers = toSeq(
g.gossipsub.getOrDefault(topic, initHashSet[PubSubPeer]()) -
g.mesh.getOrDefault(topic, initHashSet[PubSubPeer]())
)
# move gossipsub nodes second
g.moveToMeshHelper(topic, g.gossipsub)
logScope:
topic = topic
meshPeers = g.mesh.peers(topic)
newPeers = newPeers.len
if g.mesh.getOrDefault(topic).len > GossipSubDhi:
# prune peers if we've gone over
shuffle(newPeers)
# ATTN possible perf bottleneck here... score is a "red" function
# and we call a lot of Table[] etc etc
trace "getting peers", topic, peers = newPeers.len
# gather peers
var peers = toSeq(g.mesh[topic])
# sort peers by score
peers.sort(proc (x, y: string): int =
let
peerx = g.peers[x].score()
peery = g.peers[y].score()
if peerx < peery: -1
elif peerx == peery: 0
else: 1)
while g.mesh[topic].len > GossipSubD:
trace "pruning peers", peers = g.mesh[topic].len
for peer in newPeers:
# send a graft message to the peer
grafts.add peer
discard g.mesh.addPeer(topic, peer)
trace "got peer", peer = peer.id
# pop a low score peer
let
id = peers.pop()
g.mesh[topic].excl(id)
if g.mesh.peers(topic) > GossipSubDhi:
# prune peers if we've gone over
# gather peers
var mesh = toSeq(g.mesh[topic])
# sort peers by score
mesh.sort(proc (x, y: PubSubPeer): int =
let
peerx = x.score()
peery = y.score()
if peerx < peery: -1
elif peerx == peery: 0
else: 1)
# send a prune message to the peer
let
p = g.peers[id]
# TODO send a set of other peers where the pruned peer can connect to reform its mesh
await p.sendPrune(@[topic])
trace "about to prune mesh", mesh = mesh.len
for peer in mesh:
if g.mesh.peers(topic) <= GossipSubD:
break
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub.getOrDefault(topic).len.int64,
labelValues = [topic])
trace "pruning peers", peers = g.mesh.peers(topic)
# send a graft message to the peer
g.mesh.removePeer(topic, peer)
prunes.add(peer)
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.getOrDefault(topic).len.int64,
labelValues = [topic])
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub.peers(topic).int64, labelValues = [topic])
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.getOrDefault(topic).len.int64,
labelValues = [topic])
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.peers(topic).int64, labelValues = [topic])
trace "mesh balanced, got peers", peers = g.mesh.getOrDefault(topic).len,
topicId = topic
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "exception occurred re-balancing mesh", exc = exc.msg
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.peers(topic).int64, labelValues = [topic])
proc dropFanoutPeers(g: GossipSub) {.async.} =
# Send changes to peers after table updates to avoid stale state
for p in grafts:
await p.sendGraft(@[topic])
for p in prunes:
await p.sendPrune(@[topic])
trace "mesh balanced, got peers", peers = g.mesh.peers(topic),
topicId = topic
proc dropFanoutPeers(g: GossipSub) =
# drop peers that we haven't published to in
# GossipSubFanoutTTL seconds
var dropping = newSeq[string]()
for topic, val in g.lastFanoutPubSub:
if Moment.now > val:
dropping.add(topic)
let now = Moment.now()
for topic in toSeq(g.lastFanoutPubSub.keys):
let val = g.lastFanoutPubSub[topic]
if now > val:
g.fanout.del(topic)
g.lastFanoutPubSub.del(topic)
trace "dropping fanout topic", topic
for topic in dropping:
g.lastFanoutPubSub.del(topic)
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.getOrDefault(topic).len.int64, labelValues = [topic])
.set(g.fanout.peers(topic).int64, labelValues = [topic])
proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} =
## gossip iHave messages to peers
@ -278,22 +263,18 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} =
trace "topic not in gossip array, skipping", topicID = topic
continue
for id in allPeers:
for peer in allPeers:
if result.len >= GossipSubD:
trace "got gossip peers", peers = result.len
break
if allPeers.len == 0:
trace "no peers for topic, skipping", topicID = topic
break
if id in gossipPeers:
if peer in gossipPeers:
continue
if id notin result:
result[id] = controlMsg
if peer.id notin result:
result[peer.id] = controlMsg
result[id].ihave.add(ihave)
result[peer.id].ihave.add(ihave)
proc heartbeat(g: GossipSub) {.async.} =
while g.heartbeatRunning:
@ -303,7 +284,7 @@ proc heartbeat(g: GossipSub) {.async.} =
for t in toSeq(g.topics.keys):
await g.rebalanceMesh(t)
await g.dropFanoutPeers()
g.dropFanoutPeers()
# replenish known topics to the fanout
for t in toSeq(g.fanout.keys):
@ -322,35 +303,38 @@ proc heartbeat(g: GossipSub) {.async.} =
except CatchableError as exc:
trace "exception ocurred in gossipsub heartbeat", exc = exc.msg
await sleepAsync(1.seconds)
await sleepAsync(GossipSubHeartbeatInterval)
method handleDisconnect*(g: GossipSub, peer: PubSubPeer) =
## handle peer disconnects
procCall FloodSub(g).handleDisconnect(peer)
for t in toSeq(g.gossipsub.keys):
g.gossipsub[t].excl(peer.id)
g.gossipsub.removePeer(t, peer)
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub.getOrDefault(t).len.int64, labelValues = [t])
.set(g.gossipsub.peers(t).int64, labelValues = [t])
for t in toSeq(g.mesh.keys):
g.mesh[t].excl(peer.id)
g.mesh.removePeer(t, peer)
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh[t].len.int64, labelValues = [t])
.set(g.mesh.peers(t).int64, labelValues = [t])
for t in toSeq(g.fanout.keys):
g.fanout[t].excl(peer.id)
g.fanout.removePeer(t, peer)
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout[t].len.int64, labelValues = [t])
.set(g.fanout.peers(t).int64, labelValues = [t])
if peer.peerInfo.maintain:
g.explicitPeers.excl(peer.id)
for t in toSeq(g.explicit.keys):
g.explicit.removePeer(t, peer)
g.explicitPeers.excl(peer.id)
method subscribePeer*(p: GossipSub,
conn: Connection) =
conn: Connection) =
procCall PubSub(p).subscribePeer(conn)
asyncCheck p.handleConn(conn, GossipSubCodec_11)
@ -358,28 +342,36 @@ method subscribeTopic*(g: GossipSub,
topic: string,
subscribe: bool,
peerId: string) {.gcsafe, async.} =
await procCall PubSub(g).subscribeTopic(topic, subscribe, peerId)
if topic notin g.gossipsub:
g.gossipsub[topic] = initHashSet[string]()
await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId)
let peer = g.peers.getOrDefault(peerId)
if peer == nil:
debug "subscribeTopic on a nil peer!"
return
if subscribe:
trace "adding subscription for topic", peer = peerId, name = topic
# subscribe remote peer to the topic
g.gossipsub[topic].incl(peerId)
discard g.gossipsub.addPeer(topic, peer)
if peerId in g.explicitPeers:
g.explicit[topic].incl(peerId)
discard g.explicit.addPeer(topic, peer)
else:
trace "removing subscription for topic", peer = peerId, name = topic
# unsubscribe remote peer from the topic
g.gossipsub[topic].excl(peerId)
g.gossipsub.removePeer(topic, peer)
g.mesh.removePeer(topic, peer)
g.fanout.removePeer(topic, peer)
if peerId in g.explicitPeers:
g.explicit[topic].excl(peerId)
g.explicit.removePeer(topic, peer)
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.peers(topic).int64, labelValues = [topic])
libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.peers(topic).int64, labelValues = [topic])
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub[topic].len.int64, labelValues = [topic])
.set(g.gossipsub.peers(topic).int64, labelValues = [topic])
trace "gossip peers", peers = g.gossipsub[topic].len, topic
trace "gossip peers", peers = g.gossipsub.peers(topic), topic
# also rebalance current topic if we are subbed to
if topic in g.topics:
@ -387,43 +379,46 @@ method subscribeTopic*(g: GossipSub,
proc handleGraft(g: GossipSub,
peer: PubSubPeer,
grafts: seq[ControlGraft],
respControl: var ControlMessage) =
grafts: seq[ControlGraft]): seq[ControlPrune] =
for graft in grafts:
trace "processing graft message", peer = peer.id,
topicID = graft.topicID
let topic = graft.topicID
trace "processing graft message", topic, peer = peer.id
# It is an error to GRAFT on a explicit peer
if peer.peerInfo.maintain:
trace "attempt to graft an explicit peer", peer=peer.id,
topicID=graft.topicID
# and such an attempt should be logged and rejected with a PRUNE
respControl.prune.add(ControlPrune(topicID: graft.topicID))
result.add(ControlPrune(topicID: graft.topicID))
continue
if graft.topicID in g.topics:
if g.mesh.len < GossipSubD:
g.mesh[graft.topicID].incl(peer.id)
# If they send us a graft before they send us a subscribe, what should
# we do? For now, we add them to mesh but don't add them to gossipsub.
if topic in g.topics:
if g.mesh.peers(topic) < GossipSubDHi:
# In the spec, there's no mention of DHi here, but implicitly, a
# peer will be removed from the mesh on next rebalance, so we don't want
# this peer to push someone else out
if g.mesh.addPeer(topic, peer):
g.fanout.removePeer(topic, peer)
else:
trace "Peer already in mesh", topic, peer = peer.id
else:
g.gossipsub[graft.topicID].incl(peer.id)
result.add(ControlPrune(topicID: topic))
else:
respControl.prune.add(ControlPrune(topicID: graft.topicID))
result.add(ControlPrune(topicID: topic))
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh[graft.topicID].len.int64, labelValues = [graft.topicID])
libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub[graft.topicID].len.int64, labelValues = [graft.topicID])
.set(g.mesh.peers(topic).int64, labelValues = [topic])
proc handlePrune(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) =
for prune in prunes:
trace "processing prune message", peer = peer.id,
topicID = prune.topicID
if prune.topicID in g.mesh:
g.mesh[prune.topicID].excl(peer.id)
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh[prune.topicID].len.int64, labelValues = [prune.topicID])
g.mesh.removePeer(prune.topicID, peer)
libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.peers(prune.topicID).int64, labelValues = [prune.topicID])
proc handleIHave(g: GossipSub,
peer: PubSubPeer,
@ -456,7 +451,7 @@ method rpcHandler*(g: GossipSub,
for m in rpcMsgs: # for all RPC messages
if m.messages.len > 0: # if there are any messages
var toSendPeers: HashSet[string]
var toSendPeers: HashSet[PubSubPeer]
for msg in m.messages: # for every message
let msgId = g.msgIdProvider(msg)
logScope: msgId
@ -506,25 +501,24 @@ method rpcHandler*(g: GossipSub,
let (published, failed) = await g.sendHelper(toSendPeers, m.messages)
for p in failed:
let peer = g.peers.getOrDefault(p)
if not(isNil(peer)):
if not isNil(peer):
g.handleDisconnect(peer) # cleanup failed peers
trace "forwared message to peers", peers = published.len
var respControl: ControlMessage
if m.control.isSome:
var control: ControlMessage = m.control.get()
let iWant: ControlIWant = g.handleIHave(peer, control.ihave)
if iWant.messageIDs.len > 0:
respControl.iwant.add(iWant)
let messages: seq[Message] = g.handleIWant(peer, control.iwant)
g.handleGraft(peer, control.graft, respControl)
let control = m.control.get()
g.handlePrune(peer, control.prune)
respControl.iwant.add(g.handleIHave(peer, control.ihave))
respControl.prune.add(g.handleGraft(peer, control.graft))
if respControl.graft.len > 0 or respControl.prune.len > 0 or
respControl.ihave.len > 0 or respControl.iwant.len > 0:
await peer.send(@[RPCMsg(control: some(respControl), messages: messages)])
await peer.send(
@[RPCMsg(control: some(respControl),
messages: g.handleIWant(peer, control.iwant))])
method subscribe*(g: GossipSub,
topic: string,
@ -541,9 +535,9 @@ method unsubscribe*(g: GossipSub,
if topic in g.mesh:
let peers = g.mesh.getOrDefault(topic)
g.mesh.del(topic)
for id in peers:
let p = g.peers[id]
await p.sendPrune(@[topic])
for peer in peers:
await peer.sendPrune(@[topic])
method publish*(g: GossipSub,
topic: string,
@ -554,40 +548,51 @@ method publish*(g: GossipSub,
data = data.shortLog
# directly copy explicit peers
# as we will always publish to those
var peers = g.explicitPeers
var peers = initHashSet[PubSubPeer]()
if topic.len <= 0: # data could be 0/empty
return 0
if topic.len > 0: # data could be 0/empty
if g.parameters.floodPublish:
for id, peer in g.peers:
if topic in peer.topics and
peer.score() >= g.parameters.publishThreshold:
debug "publish: including flood/high score peer", peer = id
peers.incl(id)
if g.parameters.floodPublish:
for id, peer in g.peers:
if topic in peer.topics and
peer.score() >= g.parameters.publishThreshold:
debug "publish: including flood/high score peer", peer = id
peers.incl(peer)
# add always direct peers
peers.incl(g.explicit.getOrDefault(topic))
if topic in g.topics: # if we're subscribed use the mesh
peers = g.mesh.getOrDefault(topic)
peers.incl(g.mesh.getOrDefault(topic))
else: # not subscribed, send to fanout peers
# try optimistically
peers = g.fanout.getOrDefault(topic)
peers.incl(g.fanout.getOrDefault(topic))
if peers.len == 0:
# ok we had nothing.. let's try replenish inline
g.replenishFanout(topic)
peers = g.fanout.getOrDefault(topic)
peers.incl(g.fanout.getOrDefault(topic))
# even if we couldn't publish,
# we still attempted to publish
# on the topic, so it makes sense
# to update the last topic publish
# time
g.lastFanoutPubSub[topic] = Moment.fromNow(GossipSubFanoutTTL)
let
msg = Message.init(g.peerInfo, data, topic, g.sign)
msgId = g.msgIdProvider(msg)
trace "created new message", msg
trace "publishing on topic", name = topic, peers = peers
trace "publishing on topic",
topic, peers = peers.len, msg = msg.shortLog()
if msgId notin g.mcache:
g.mcache.put(msgId, msg)
let (published, failed) = await g.sendHelper(peers, @[msg])
for p in failed:
let peer = g.peers.getOrDefault(p)
g.handleDisconnect(peer) # cleanup failed peers
if not isNil(peer):
g.handleDisconnect(peer) # cleanup failed peers
if published.len > 0:
libp2p_pubsub_messages_published.inc(labelValues = [topic])
@ -625,9 +630,9 @@ method initPubSub*(g: GossipSub) =
randomize()
g.mcache = newMCache(GossipSubHistoryGossip, GossipSubHistoryLength)
g.mesh = initTable[string, HashSet[string]]() # meshes - topic to peer
g.fanout = initTable[string, HashSet[string]]() # fanout - topic to peer
g.gossipsub = initTable[string, HashSet[string]]()# topic to peer map of all gossipsub peers
g.mesh = initTable[string, HashSet[PubSubPeer]]() # meshes - topic to peer
g.fanout = initTable[string, HashSet[PubSubPeer]]() # fanout - topic to peer
g.gossipsub = initTable[string, HashSet[PubSubPeer]]()# topic to peer map of all gossipsub peers
g.lastFanoutPubSub = initTable[string, Moment]() # last publish time for fanout topics
g.gossip = initTable[string, seq[ControlIHave]]() # pending gossip
g.control = initTable[string, ControlMessage]() # pending control messages

View File

@ -31,6 +31,8 @@ declareCounter(libp2p_pubsub_validation_failure, "pubsub failed validated messag
declarePublicCounter(libp2p_pubsub_messages_published, "published messages", labels = ["topic"])
type
PeerTable* = Table[string, HashSet[PubSubPeer]]
SendRes = tuple[published: seq[string], failed: seq[string]] # keep private
TopicHandler* = proc(topic: string,
@ -58,19 +60,28 @@ type
cleanupLock: AsyncLock
validators*: Table[string, HashSet[ValidatorHandler]]
observers: ref seq[PubSubObserver] # ref as in smart_ptr
msgIdProvider*: MsgIdProvider # Turn message into message id (not nil)
msgIdProvider*: MsgIdProvider # Turn message into message id (not nil)
proc hasPeerID*(t: PeerTable, topic, peerId: string): bool =
# unefficient but used only in tests!
let peers = t.getOrDefault(topic)
if peers.len == 0:
false
else:
let ps = toSeq(peers)
ps.any do (peer: PubSubPeer) -> bool:
peer.id == peerId
method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} =
## handle peer disconnects
##
if peer.id in p.peers:
trace "deleting peer", peer = peer.id, stack = getStackTrace()
p.peers[peer.id] = nil
if not isNil(peer.peerInfo) and peer.id in p.peers:
trace "deleting peer", peer = peer.id
p.peers.del(peer.id)
trace "peer disconnected", peer = peer.id
# metrics
libp2p_pubsub_peers.set(p.peers.len.int64)
trace "peer disconnected", peer = peer.id
# metrics
libp2p_pubsub_peers.set(p.peers.len.int64)
proc sendSubs*(p: PubSub,
peer: PubSubPeer,
@ -127,19 +138,22 @@ method rpcHandler*(p: PubSub,
trace "about to subscribe to topic", topicId = s.topic
await p.subscribeTopic(s.topic, s.subscribe, peer.id)
proc getPeer(p: PubSub,
peerInfo: PeerInfo,
proto: string): PubSubPeer =
proc getOrCreatePeer(p: PubSub,
peerInfo: PeerInfo,
proto: string): PubSubPeer =
if peerInfo.id in p.peers:
return p.peers[peerInfo.id]
# create new pubsub peer
let peer = newPubSubPeer(peerInfo, proto)
trace "created new pubsub peer", peerId = peer.id, stack = getStackTrace()
trace "created new pubsub peer", peerId = peer.id
p.peers[peer.id] = peer
peer.observers = p.observers
# metrics
libp2p_pubsub_peers.set(p.peers.len.int64)
return peer
method handleConn*(p: PubSub,
@ -165,7 +179,7 @@ method handleConn*(p: PubSub,
# call pubsub rpc handler
await p.rpcHandler(peer, msgs)
let peer = p.getPeer(conn.peerInfo, proto)
let peer = p.getOrCreatePeer(conn.peerInfo, proto)
let topics = toSeq(p.topics.keys)
if topics.len > 0:
await p.sendSubs(peer, topics, true)
@ -184,23 +198,27 @@ method handleConn*(p: PubSub,
method subscribePeer*(p: PubSub, conn: Connection) {.base.} =
if not(isNil(conn)):
let peer = p.getPeer(conn.peerInfo, p.codec)
let peer = p.getOrCreatePeer(conn.peerInfo, p.codec)
trace "subscribing to peer", peerId = conn.peerInfo.id
if not peer.connected:
peer.conn = conn
method unsubscribePeer*(p: PubSub, peerInfo: PeerInfo) {.base, async.} =
let peer = p.getPeer(peerInfo, p.codec)
trace "unsubscribing from peer", peerId = $peerInfo
if not(isNil(peer.conn)):
await peer.conn.close()
if peerInfo.id in p.peers:
let peer = p.peers[peerInfo.id]
p.handleDisconnect(peer)
trace "unsubscribing from peer", peerId = $peerInfo
if not(isNil(peer.conn)):
await peer.conn.close()
proc connected*(p: PubSub, peer: PeerInfo): bool =
let peer = p.getPeer(peer, p.codec)
if not(isNil(peer)):
return peer.connected
p.handleDisconnect(peer)
proc connected*(p: PubSub, peerInfo: PeerInfo): bool =
if peerInfo.id in p.peers:
let peer = p.peers[peerInfo.id]
if not(isNil(peer)):
return peer.connected
method unsubscribe*(p: PubSub,
topics: seq[TopicPair]) {.base, async.} =
@ -212,6 +230,11 @@ method unsubscribe*(p: PubSub,
if h == t.handler:
p.topics[t.topic].handler.del(i)
# make sure we delete the topic if
# no more handlers are left
if p.topics[t.topic].handler.len <= 0:
p.topics.del(t.topic)
method unsubscribe*(p: PubSub,
topic: string,
handler: TopicHandler): Future[void] {.base.} =
@ -242,20 +265,16 @@ method subscribe*(p: PubSub,
libp2p_pubsub_topics.inc()
proc sendHelper*(p: PubSub,
sendPeers: HashSet[string],
sendPeers: HashSet[PubSubPeer],
msgs: seq[Message]): Future[SendRes] {.async.} =
var sent: seq[tuple[id: string, fut: Future[void]]]
for sendPeer in sendPeers:
# avoid sending to self
if sendPeer == p.peerInfo.id:
if sendPeer.peerInfo == p.peerInfo:
continue
let peer = p.peers.getOrDefault(sendPeer)
if isNil(peer):
continue
trace "sending messages to peer", peer = peer.id, msgs
sent.add((id: peer.id, fut: peer.send(@[RPCMsg(messages: msgs)])))
trace "sending messages to peer", peer = sendPeer.id, msgs
sent.add((id: sendPeer.id, fut: sendPeer.send(@[RPCMsg(messages: msgs)])))
var published: seq[string]
var failed: seq[string]

View File

@ -57,6 +57,35 @@ func score*(p: PubSubPeer): float64 =
# TODO
0.0
func hash*(p: PubSubPeer): Hash =
# int is either 32/64, so intptr basically, pubsubpeer is a ref
cast[pointer](p).hash
func `==`*(a, b: PubSubPeer): bool =
# override equiality to support both nil and peerInfo comparisons
# this in the future will allow us to recycle refs
let
aptr = cast[pointer](a)
bptr = cast[pointer](b)
if aptr == nil:
if bptr == nil:
true
else:
false
elif bptr == nil:
false
else:
if a.peerInfo == nil:
if b.peerInfo == nil:
true
else:
false
else:
if b.peerInfo == nil:
false
else:
a.peerInfo.id == b.peerInfo.id
proc id*(p: PubSubPeer): string = p.peerInfo.id
proc connected*(p: PubSubPeer): bool =
@ -176,14 +205,24 @@ proc sendMsg*(p: PubSubPeer,
p.send(@[RPCMsg(messages: @[Message.init(p.peerInfo, data, topic, sign)])])
proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} =
for topic in topics:
trace "sending graft msg to peer", peer = p.id, topicID = topic
await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))])
try:
for topic in topics:
trace "sending graft msg to peer", peer = p.id, topicID = topic
await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))])
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "Could not send graft", msg = exc.msg
proc sendPrune*(p: PubSubPeer, topics: seq[string], peers: seq[PeerInfoMsg] = @[], backoff: uint64 = 0) {.async.} =
for topic in topics:
trace "sending prune msg to peer", peer = p.id, topicID = topic
await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic, peers: peers, backoff: backoff)])))])
proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} =
try:
for topic in topics:
trace "sending prune msg to peer", peer = p.id, topicID = topic
await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))])
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "Could not send prune", msg = exc.msg
proc `$`*(p: PubSubPeer): string =
p.id

View File

@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string =
byteutils.toHex(m.seqno) & m.fromPeer.pretty
proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} =
var buff = initProtoBuffer()
encodeMessage(msg, buff)
p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes()
p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes()
proc verify*(m: Message, p: PeerInfo): bool =
if m.signature.len > 0 and m.key.len > 0:
@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool =
msg.signature = @[]
msg.key = @[]
var buff = initProtoBuffer()
encodeMessage(msg, buff)
var remote: Signature
var key: PublicKey
if remote.init(m.signature) and key.init(m.key):
trace "verifying signature", remoteSignature = remote
result = remote.verify(PubSubPrefix & buff.buffer, key)
result = remote.verify(PubSubPrefix & encodeMessage(msg), key)
if result:
libp2p_pubsub_sig_verify_success.inc()

View File

@ -14,269 +14,247 @@ import messages,
../../../utility,
../../../protobuf/minprotobuf
proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, graft.topicID))
proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) =
var ipb = initProtoBuffer()
ipb.write(1, graft.topicID)
ipb.finish()
pb.write(field, ipb)
proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} =
trace "decoding graft msg", buffer = pb.buffer.shortLog
while true:
var topic: string
if pb.getString(1, topic) < 0:
break
proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) =
var ipb = initProtoBuffer()
ipb.write(1, prune.topicID)
ipb.finish()
pb.write(field, ipb)
trace "read topic field from graft msg", topicID = topic
result.add(ControlGraft(topicID: topic))
proc encodePeerInfo*(info: PeerInfoMsg, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, info.peerID))
pb.write(initProtoField(2, info.signedPeerRecord))
proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, prune.topicID))
var peers = initProtoBuffer()
for p in prune.peers:
p.encodePeerInfo(peers)
peers.finish()
pb.write(initProtoField(2, peers))
pb.write(initProtoField(3, prune.backoff))
proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} =
trace "decoding prune msg"
while true:
var topic: string
if pb.getString(1, topic) < 0:
break
trace "read topic field from prune msg", topicID = topic
result.add(ControlPrune(topicID: topic))
proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, ihave.topicID))
proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) =
var ipb = initProtoBuffer()
ipb.write(1, ihave.topicID)
for mid in ihave.messageIDs:
pb.write(initProtoField(2, mid))
ipb.write(2, mid)
ipb.finish()
pb.write(field, ipb)
proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} =
trace "decoding ihave msg"
while true:
var control: ControlIHave
if pb.getString(1, control.topicID) < 0:
trace "topic field missing from ihave msg"
break
trace "read topic field", topicID = control.topicID
while true:
var mid: string
if pb.getString(2, mid) < 0:
break
trace "read messageID field", mid = mid
control.messageIDs.add(mid)
result.add(control)
proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} =
proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) =
var ipb = initProtoBuffer()
for mid in iwant.messageIDs:
pb.write(initProtoField(1, mid))
ipb.write(1, mid)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} =
trace "decoding iwant msg"
proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) =
var ipb = initProtoBuffer()
for ihave in control.ihave:
ipb.write(1, ihave)
for iwant in control.iwant:
ipb.write(2, iwant)
for graft in control.graft:
ipb.write(3, graft)
for prune in control.prune:
ipb.write(4, prune)
if len(ipb.buffer) > 0:
ipb.finish()
pb.write(field, ipb)
var control: ControlIWant
while true:
var mid: string
if pb.getString(1, mid) < 0:
break
control.messageIDs.add(mid)
trace "read messageID field", mid = mid
result.add(control)
proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} =
if control.ihave.len > 0:
var ihave = initProtoBuffer()
for h in control.ihave:
h.encodeIHave(ihave)
# write messages to protobuf
ihave.finish()
pb.write(initProtoField(1, ihave))
if control.iwant.len > 0:
var iwant = initProtoBuffer()
for w in control.iwant:
w.encodeIWant(iwant)
# write messages to protobuf
iwant.finish()
pb.write(initProtoField(2, iwant))
if control.graft.len > 0:
var graft = initProtoBuffer()
for g in control.graft:
g.encodeGraft(graft)
# write messages to protobuf
graft.finish()
pb.write(initProtoField(3, graft))
if control.prune.len > 0:
var prune = initProtoBuffer()
for p in control.prune:
p.encodePrune(prune)
# write messages to protobuf
prune.finish()
pb.write(initProtoField(4, prune))
proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} =
trace "decoding control submessage"
var control: ControlMessage
while true:
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in Control msg"
break
of 1:
control.ihave &= pb.decodeIHave()
of 2:
control.iwant &= pb.decodeIWant()
of 3:
control.graft &= pb.decodeGraft()
of 4:
control.prune &= pb.decodePrune()
else:
raise newException(CatchableError, "message type not recognized")
if result.isNone:
result = some(control)
proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, subs.subscribe))
pb.write(initProtoField(2, subs.topic))
proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} =
while true:
var subOpt: SubOpts
var subscr: uint
discard pb.getVarintValue(1, subscr)
subOpt.subscribe = cast[bool](subscr)
trace "read subscribe field", subscribe = subOpt.subscribe
if pb.getString(2, subOpt.topic) < 0:
break
trace "read subscribe field", topicName = subOpt.topic
result.add(subOpt)
trace "got subscriptions", subscriptions = result
proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} =
pb.write(initProtoField(1, msg.fromPeer.getBytes()))
pb.write(initProtoField(2, msg.data))
pb.write(initProtoField(3, msg.seqno))
for t in msg.topicIDs:
pb.write(initProtoField(4, t))
if msg.signature.len > 0:
pb.write(initProtoField(5, msg.signature))
if msg.key.len > 0:
pb.write(initProtoField(6, msg.key))
proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) =
var ipb = initProtoBuffer()
ipb.write(1, uint64(subs.subscribe))
ipb.write(2, subs.topic)
ipb.finish()
pb.write(field, ipb)
proc encodeMessage*(msg: Message): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, msg.fromPeer)
pb.write(2, msg.data)
pb.write(3, msg.seqno)
for topic in msg.topicIDs:
pb.write(4, topic)
if len(msg.signature) > 0:
pb.write(5, msg.signature)
if len(msg.key) > 0:
pb.write(6, msg.key)
pb.finish()
pb.buffer
proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} =
# TODO: which of this fields are really optional?
while true:
var msg: Message
var fromPeer: seq[byte]
if pb.getBytes(1, fromPeer) < 0:
break
try:
msg.fromPeer = PeerID.init(fromPeer).tryGet()
except CatchableError as err:
debug "Invalid fromPeer in message", msg = err.msg
break
proc write*(pb: var ProtoBuffer, field: int, msg: Message) =
pb.write(field, encodeMessage(msg))
trace "read message field", fromPeer = msg.fromPeer.pretty
proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} =
trace "decodeGraft: decoding message"
var control = ControlGraft()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeGraft: read topicId", topic_id = topicId
else:
trace "decodeGraft: topicId is missing"
control
if pb.getBytes(2, msg.data) < 0:
break
trace "read message field", data = msg.data.shortLog
proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} =
trace "decodePrune: decoding message"
var control = ControlPrune()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodePrune: read topicId", topic_id = topicId
else:
trace "decodePrune: topicId is missing"
control
if pb.getBytes(3, msg.seqno) < 0:
break
trace "read message field", seqno = msg.seqno.shortLog
proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} =
trace "decodeIHave: decoding message"
var control = ControlIHave()
var topicId: string
if pb.getField(1, topicId):
control.topicId = topicId
trace "decodeIHave: read topicId", topic_id = topicId
else:
trace "decodeIHave: topicId is missing"
if pb.getRepeatedField(2, control.messageIDs):
trace "decodeIHave: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIHave: no messageIDs"
control
var topic: string
while true:
if pb.getString(4, topic) < 0:
break
msg.topicIDs.add(topic)
trace "read message field", topicName = topic
topic = ""
proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} =
trace "decodeIWant: decoding message"
var control = ControlIWant()
if pb.getRepeatedField(1, control.messageIDs):
trace "decodeIWant: read messageIDs", message_ids = control.messageIDs
else:
trace "decodeIWant: no messageIDs"
discard pb.getBytes(5, msg.signature)
trace "read message field", signature = msg.signature.shortLog
proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} =
trace "decodeControl: decoding message"
var buffer: seq[byte]
if pb.getField(3, buffer):
var control: ControlMessage
var cpb = initProtoBuffer(buffer)
var ihavepbs: seq[seq[byte]]
var iwantpbs: seq[seq[byte]]
var graftpbs: seq[seq[byte]]
var prunepbs: seq[seq[byte]]
discard pb.getBytes(6, msg.key)
trace "read message field", key = msg.key.shortLog
discard cpb.getRepeatedField(1, ihavepbs)
discard cpb.getRepeatedField(2, iwantpbs)
discard cpb.getRepeatedField(3, graftpbs)
discard cpb.getRepeatedField(4, prunepbs)
result.add(msg)
for item in ihavepbs:
control.ihave.add(decodeIHave(initProtoBuffer(item)))
for item in iwantpbs:
control.iwant.add(decodeIWant(initProtoBuffer(item)))
for item in graftpbs:
control.graft.add(decodeGraft(initProtoBuffer(item)))
for item in prunepbs:
control.prune.add(decodePrune(initProtoBuffer(item)))
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} =
result = initProtoBuffer()
trace "encoding msg: ", msg = msg.shortLog
trace "decodeControl: "
some(control)
else:
none[ControlMessage]()
if msg.subscriptions.len > 0:
for s in msg.subscriptions:
var subs = initProtoBuffer()
encodeSubs(s, subs)
# write subscriptions to protobuf
subs.finish()
result.write(initProtoField(1, subs))
proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} =
trace "decodeSubscription: decoding message"
var subflag: uint64
var sub = SubOpts()
if pb.getField(1, subflag):
sub.subscribe = bool(subflag)
trace "decodeSubscription: read subscribe", subscribe = subflag
else:
trace "decodeSubscription: subscribe is missing"
if pb.getField(2, sub.topic):
trace "decodeSubscription: read topic", topic = sub.topic
else:
trace "decodeSubscription: topic is missing"
if msg.messages.len > 0:
var messages = initProtoBuffer()
for m in msg.messages:
encodeMessage(m, messages)
sub
# write messages to protobuf
messages.finish()
result.write(initProtoField(2, messages))
proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} =
trace "decodeSubscriptions: decoding message"
var subpbs: seq[seq[byte]]
var subs: seq[SubOpts]
if pb.getRepeatedField(1, subpbs):
trace "decodeSubscriptions: read subscriptions", count = len(subpbs)
for item in subpbs:
let sub = decodeSubscription(initProtoBuffer(item))
subs.add(sub)
if msg.control.isSome:
var control = initProtoBuffer()
msg.control.get.encodeControl(control)
if len(subs) == 0:
trace "decodeSubscription: no subscriptions found"
# write messages to protobuf
control.finish()
result.write(initProtoField(3, control))
subs
if result.buffer.len > 0:
result.finish()
proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} =
trace "decodeMessage: decoding message"
var msg: Message
if pb.getField(1, msg.fromPeer):
trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty()
else:
trace "decodeMessage: fromPeer is missing"
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} =
if pb.getField(2, msg.data):
trace "decodeMessage: read data", data = msg.data.shortLog()
else:
trace "decodeMessage: data is missing"
if pb.getField(3, msg.seqno):
trace "decodeMessage: read seqno", seqno = msg.data.shortLog()
else:
trace "decodeMessage: seqno is missing"
if pb.getRepeatedField(4, msg.topicIDs):
trace "decodeMessage: read topics", topic_ids = msg.topicIDs
else:
trace "decodeMessage: topics are missing"
if pb.getField(5, msg.signature):
trace "decodeMessage: read signature", signature = msg.signature.shortLog()
else:
trace "decodeMessage: signature is missing"
if pb.getField(6, msg.key):
trace "decodeMessage: read public key", key = msg.key.shortLog()
else:
trace "decodeMessage: public key is missing"
msg
proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} =
trace "decodeMessages: decoding message"
var msgpbs: seq[seq[byte]]
var msgs: seq[Message]
if pb.getRepeatedField(2, msgpbs):
trace "decodeMessages: read messages", count = len(msgpbs)
for item in msgpbs:
let msg = decodeMessage(initProtoBuffer(item))
msgs.add(msg)
if len(msgs) == 0:
trace "decodeMessages: no messages found"
msgs
proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer =
trace "encodeRpcMsg: encoding message", msg = msg.shortLog()
var pb = initProtoBuffer()
for item in msg.subscriptions:
pb.write(1, item)
for item in msg.messages:
pb.write(2, item)
if msg.control.isSome():
pb.write(3, msg.control.get())
if len(pb.buffer) > 0:
pb.finish()
result = pb
proc decodeRpcMsg*(msg: seq[byte]): RPCMsg =
trace "decodeRpcMsg: decoding message", msg = msg.shortLog()
var pb = initProtoBuffer(msg)
var rpcMsg: RPCMsg
rpcMsg.messages = pb.decodeMessages()
rpcMsg.subscriptions = pb.decodeSubscriptions()
rpcMsg.control = pb.decodeControl()
while true:
# decode SubOpts array
var field = pb.enterSubMessage()
trace "processing submessage", field = field
case field:
of 0:
trace "no submessage found in RPC msg"
break
of 1:
result.subscriptions &= pb.decodeSubs()
of 2:
result.messages &= pb.decodeMessages()
of 3:
result.control = pb.decodeControl()
else:
raise newException(CatchableError, "message type not recognized")
rpcMsg

View File

@ -12,15 +12,13 @@ import chronicles
import bearssl
import stew/[endians2, byteutils]
import nimcrypto/[utils, sha2, hmac]
import ../../stream/lpstream
import ../../stream/[connection, streamseq]
import ../../peerid
import ../../peerinfo
import ../../protobuf/minprotobuf
import ../../utility
import ../../stream/lpstream
import secure,
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf],
../../stream/bufferstream
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf]
logScope:
topics = "noise"
@ -34,7 +32,7 @@ const
ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
# Empty is a special value which indicates k has not yet been initialized.
EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
EmptyKey = default(ChaChaPolyKey)
NonceMax = uint64.high - 1 # max is reserved
NoiseSize = 32
MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len)
@ -72,7 +70,7 @@ type
Noise* = ref object of Secure
rng: ref BrHmacDrbgContext
localPrivateKey: PrivateKey
localPublicKey: PublicKey
localPublicKey: seq[byte]
noiseKeys: KeyPair
commonPrologue: seq[byte]
outgoing: bool
@ -89,7 +87,7 @@ type
# Utility
proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair =
result.privateKey = Curve25519Key.random(rng).tryGet()
result.privateKey = Curve25519Key.random(rng)
result.publicKey = result.privateKey.public()
proc hashProtocol(name: string): MDigest[256] =
@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key =
proc hasKey(cs: CipherState): bool =
cs.k != EmptyKey
proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
var
tag: ChaChaPolyTag
nonce: ChaChaPolyNonce
np = cast[ptr uint64](addr nonce[4])
np[] = state.n
nonce[4..<12] = toBytesLE(state.n)
result = @data
ChaChaPoly.encrypt(state.k, nonce, tag, result, ad)
inc state.n
@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte]
result &= tag
trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1
proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] =
proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] =
var
tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag
tagOut = tagIn
tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag
tagOut: ChaChaPolyTag
nonce: ChaChaPolyNonce
np = cast[ptr uint64](addr nonce[4])
np[] = state.n
nonce[4..<12] = toBytesLE(state.n)
result = data[0..(data.high - ChaChaPolyTag.len)]
ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad)
trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n
@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) =
ss.cs = CipherState(k: temp_keys[1])
trace "mixKey", key = ss.cs.k.shortLog
proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
proc mixHash(ss: var SymmetricState; data: openArray[byte]) =
var ctx: sha256
ctx.init()
ctx.update(ss.h.data)
@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) =
trace "mixHash", hash = ss.h.data.shortLog
# We might use this for other handshake patterns/tokens
proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} =
var
temp_keys: array[3, ChaChaPolyKey]
sha256.hkdf(ss.ck, ikm, [], temp_keys)
@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} =
ss.mixHash(temp_keys[1])
ss.cs = CipherState(k: temp_keys[2])
proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
# according to spec if key is empty leave plaintext
if ss.cs.hasKey:
result = ss.cs.encryptWithAd(ss.h.data, data)
@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
result = @data
ss.mixHash(result)
proc decryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] =
proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] =
# according to spec if key is empty leave plaintext
if ss.cs.hasKey:
result = ss.cs.decryptWithAd(ss.h.data, data)
@ -202,13 +198,13 @@ template write_e: untyped =
trace "noise write e"
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
hs.e = genKeyPair(p.rng[])
msg &= hs.e.publicKey
msg.add hs.e.publicKey
hs.ss.mixHash(hs.e.publicKey)
template write_s: untyped =
trace "noise write s"
# Appends EncryptAndHash(s.public_key) to the buffer.
msg &= hs.ss.encryptAndHash(hs.s.publicKey)
msg.add hs.ss.encryptAndHash(hs.s.publicKey)
template dh_ee: untyped =
trace "noise dh ee"
@ -244,8 +240,8 @@ template read_e: untyped =
raise newException(NoiseHandshakeError, "Noise E, expected more data")
# Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high]
msg = msg[Curve25519Key.len..msg.high]
hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high)
msg.consume(Curve25519Key.len)
hs.ss.mixHash(hs.re)
template read_s: untyped =
@ -253,30 +249,33 @@ template read_s: untyped =
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise.
# Sets rs (which must be empty) to DecryptAndHash(temp).
let
temp =
rsLen =
if hs.ss.cs.hasKey:
if msg.len < Curve25519Key.len + ChaChaPolyTag.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data")
msg[0..Curve25519Key.high + ChaChaPolyTag.len]
Curve25519Key.len + ChaChaPolyTag.len
else:
if msg.len < Curve25519Key.len:
raise newException(NoiseHandshakeError, "Noise S, expected more data")
msg[0..Curve25519Key.high]
msg = msg[temp.len..msg.high]
let plain = hs.ss.decryptAndHash(temp)
hs.rs[0..Curve25519Key.high] = plain
Curve25519Key.len
hs.rs[0..Curve25519Key.high] =
hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1))
msg.consume(rsLen)
proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} =
var besize: array[2, byte]
await sconn.readExactly(addr besize[0], besize.len)
let size = uint16.fromBytesBE(besize).int
trace "receiveHSMessage", size
if size == 0:
return
var buffer = newSeq[byte](size)
if buffer.len > 0:
await sconn.readExactly(addr buffer[0], buffer.len)
await sconn.readExactly(addr buffer[0], buffer.len)
return buffer
proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] =
var
lesize = buf.len.uint16
besize = lesize.toBytesBE
@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} =
trace "sendHSMessage", size = lesize
outbuf &= besize
outbuf &= buf
await sconn.write(outbuf)
sconn.write(outbuf)
proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
proc handshakeXXOutbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
const initiator = true
var
hs = HandshakeState.init()
p2psecret = p2pProof.buffer
hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
try:
# -> e
var msg: seq[byte]
hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
write_e()
# -> e
var msg: StreamSeq
# IK might use this btw!
msg &= hs.ss.encryptAndHash(@[])
write_e()
await conn.sendHSMessage(msg)
# IK might use this btw!
msg.add hs.ss.encryptAndHash([])
# <- e, ee, s, es
await conn.sendHSMessage(msg.data)
msg = await conn.receiveHSMessage()
# <- e, ee, s, es
read_e()
dh_ee()
read_s()
dh_es()
msg.assign(await conn.receiveHSMessage())
let remoteP2psecret = hs.ss.decryptAndHash(msg)
read_e()
dh_ee()
read_s()
dh_es()
# -> s, se
let remoteP2psecret = hs.ss.decryptAndHash(msg.data)
msg.clear()
msg.setLen(0)
# -> s, se
write_s()
dh_se()
write_s()
dh_se()
# last payload must follow the ecrypted way of sending
msg &= hs.ss.encryptAndHash(p2psecret)
# last payload must follow the encrypted way of sending
msg.add hs.ss.encryptAndHash(p2psecret)
await conn.sendHSMessage(msg)
await conn.sendHSMessage(msg.data)
let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)
proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} =
proc handshakeXXInbound(
p: Noise, conn: Connection,
p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} =
const initiator = false
var
hs = HandshakeState.init()
p2psecret = p2pProof.buffer
hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
try:
hs.ss.mixHash(p.commonPrologue)
hs.s = p.noiseKeys
# -> e
# -> e
var msg = await conn.receiveHSMessage()
var msg: StreamSeq
msg.add(await conn.receiveHSMessage())
read_e()
read_e()
# we might use this early data one day, keeping it here for clarity
let earlyData {.used.} = hs.ss.decryptAndHash(msg)
# we might use this early data one day, keeping it here for clarity
let earlyData {.used.} = hs.ss.decryptAndHash(msg.data)
# <- e, ee, s, es
# <- e, ee, s, es
msg.setLen(0)
msg.consume(msg.len)
write_e()
dh_ee()
write_s()
dh_es()
write_e()
dh_ee()
write_s()
dh_es()
msg &= hs.ss.encryptAndHash(p2psecret)
msg.add hs.ss.encryptAndHash(p2psecret)
await conn.sendHSMessage(msg)
await conn.sendHSMessage(msg.data)
msg.clear()
# -> s, se
# -> s, se
msg = await conn.receiveHSMessage()
msg.add(await conn.receiveHSMessage())
read_s()
dh_se()
read_s()
dh_se()
let remoteP2psecret = hs.ss.decryptAndHash(msg)
let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
let
remoteP2psecret = hs.ss.decryptAndHash(msg.data)
(cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)
finally:
burnMem(hs)
method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
while true: # Discard 0-length payloads
@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.
while left > 0:
let
chunkSize = if left > MaxPlainSize: MaxPlainSize else: left
cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1))
cipher = sconn.writeCs.encryptWithAd(
[], message.toOpenArray(offset, offset + chunkSize - 1))
left = left - chunkSize
offset = offset + chunkSize
var
@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon
var
libp2pProof = initProtoBuffer()
libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet()))
libp2pProof.write(initProtoField(2, signedPayload.getBytes()))
libp2pProof.write(1, p.localPublicKey)
libp2pProof.write(2, signedPayload.getBytes())
# data field also there but not used!
libp2pProof.finish()
let handshakeRes =
var handshakeRes =
if initiator:
await handshakeXXOutbound(p, conn, libp2pProof)
await handshakeXXOutbound(p, conn, libp2pProof.buffer)
else:
await handshakeXXInbound(p, conn, libp2pProof)
await handshakeXXInbound(p, conn, libp2pProof.buffer)
var
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
remotePubKey: PublicKey
remotePubKeyBytes: seq[byte]
remoteSig: Signature
remoteSigBytes: seq[byte]
var secure = try:
var
remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret)
remotePubKey: PublicKey
remotePubKeyBytes: seq[byte]
remoteSig: Signature
remoteSigBytes: seq[byte]
if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0:
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if remoteProof.getLengthValue(2, remoteSigBytes) <= 0:
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not(remoteProof.getField(1, remotePubKeyBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not(remoteProof.getField(2, remoteSigBytes)):
raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remotePubKey.init(remotePubKeyBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
if not remoteSig.init(remoteSigBytes):
raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")")
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else:
trace "Remote signature verified", peer = $conn
let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes
if not remoteSig.verify(verifyPayload, remotePubKey):
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else:
trace "Remote signature verified", peer = $conn
if initiator and not isNil(conn.peerInfo):
let pid = PeerID.init(remotePubKey)
if not conn.peerInfo.peerId.validate():
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
if pid.isErr or pid.get() != conn.peerInfo.peerId:
var
failedKey: PublicKey
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
if initiator and not isNil(conn.peerInfo):
let pid = PeerID.init(remotePubKey)
if not conn.peerInfo.peerId.validate():
raise newException(NoiseHandshakeError, "Failed to validate peerId.")
if pid.isErr or pid.get() != conn.peerInfo.peerId:
var
failedKey: PublicKey
discard extractPublicKey(conn.peerInfo.peerId, failedKey)
debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey
raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId)
var secure = NoiseConnection.init(conn,
PeerInfo.init(remotePubKey),
conn.observedAddr)
if initiator:
secure.readCs = handshakeRes.cs2
secure.writeCs = handshakeRes.cs1
else:
secure.readCs = handshakeRes.cs1
secure.writeCs = handshakeRes.cs2
var tmp = NoiseConnection.init(
conn, PeerInfo.init(remotePubKey), conn.observedAddr)
if initiator:
tmp.readCs = handshakeRes.cs2
tmp.writeCs = handshakeRes.cs1
else:
tmp.readCs = handshakeRes.cs1
tmp.writeCs = handshakeRes.cs2
tmp
finally:
burnMem(handshakeRes)
trace "Noise handshake completed!", initiator, peer = $secure.peerInfo
return secure
method close*(s: NoiseConnection) {.async.} =
await procCall SecureConn(s).close()
burnMem(s.readCs)
burnMem(s.writeCs)
method init*(p: Noise) {.gcsafe.} =
procCall Secure(p).init()
p.codec = NoiseCodec
@ -491,7 +510,7 @@ proc newNoise*(
rng: rng,
outgoing: outgoing,
localPrivateKey: privateKey,
localPublicKey: privateKey.getKey().tryGet(),
localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(),
noiseKeys: genKeyPair(rng[]),
commonPrologue: commonPrologue,
)

View File

@ -44,11 +44,11 @@ method initStream*(s: SecureConn) =
procCall Connection(s).initStream()
method close*(s: SecureConn) {.async.} =
await procCall Connection(s).close()
if not(isNil(s.stream)):
await s.stream.close()
await procCall Connection(s).close()
method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
doAssert(false, "Not implemented!")
@ -61,10 +61,9 @@ proc handleConn*(s: Secure,
conn: Connection,
initiator: bool): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator)
conn.closeEvent.wait()
.addCallback do(udata: pointer = nil):
if not(isNil(sconn)):
if not isNil(sconn):
conn.closeEvent.wait()
.addCallback do(udata: pointer = nil):
asyncCheck sconn.close()
return sconn

View File

@ -198,6 +198,15 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
await s.dataReadEvent.wait()
s.dataReadEvent.clear()
proc drainBuffer*(s: BufferStream) {.async.} =
## wait for all data in the buffer to be consumed
##
trace "draining buffer", len = s.len
while s.len > 0:
await s.dataReadEvent.wait()
s.dataReadEvent.clear()
method readOnce*(s: BufferStream,
pbytes: pointer,
nbytes: int):

View File

@ -7,9 +7,8 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import oids
import chronos, chronicles
import connection, ../utility
import connection
logScope:
topics = "chronosstream"
@ -75,15 +74,10 @@ method close*(s: ChronosStream) {.async.} =
if not s.isClosed:
trace "shutting down chronos stream", address = $s.client.remoteAddress(),
oid = s.oid
# TODO: the sequence here matters
# don't move it after the connections
# close bellow
if not s.client.closed():
await s.client.closeWait()
await procCall Connection(s).close()
except CancelledError as exc:
raise exc
except CatchableError as exc:

View File

@ -102,22 +102,32 @@ method readOnce*(s: LPStream,
doAssert(false, "not implemented!")
proc readExactly*(s: LPStream,
pbytes: pointer,
nbytes: int):
Future[void] {.async.} =
pbytes: pointer,
nbytes: int):
Future[void] {.async.} =
if s.atEof:
raise newLPStreamEOFError()
logScope:
nbytes = nbytes
obName = s.objName
stack = getStackTrace()
oid = $s.oid
var pbuffer = cast[ptr UncheckedArray[byte]](pbytes)
var read = 0
while read < nbytes and not(s.atEof()):
read += await s.readOnce(addr pbuffer[read], nbytes - read)
if read < nbytes:
trace "incomplete data received", read
raise newLPStreamIncompleteError()
proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] {.async, deprecated: "todo".} =
proc readLine*(s: LPStream,
limit = 0,
sep = "\r\n"): Future[string]
{.async, deprecated: "todo".} =
# TODO replace with something that exploits buffering better
var lim = if limit <= 0: -1 else: limit
var state = 0

View File

@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] =
# TODO a double-hash comment here breaks compile (!)
v.buf.toOpenArray(v.rpos, v.wpos - 1)
template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] =
# Data that is ready to be consumed
# TODO a double-hash comment here breaks compile (!)
v.buf.toOpenArray(v.rpos + b, v.rpos + e - b)
func consume*(v: var StreamSeq, n: int) =
## Mark `n` bytes that were returned via `data` as consumed
v.rpos += n
@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int =
copyMem(addr buf[0], addr v.buf[v.rpos], bytes)
v.consume(bytes)
bytes
func clear*(v: var StreamSeq) =
v.consume(v.len)
func assign*(v: var StreamSeq, buf: openArray[byte]) =
v.clear()
v.add(buf)

View File

@ -10,7 +10,6 @@
import tables,
sequtils,
options,
strformat,
sets,
algorithm,
oids
@ -20,19 +19,17 @@ import chronos,
metrics
import stream/connection,
stream/chronosstream,
transports/transport,
multistream,
multiaddress,
protocols/protocol,
protocols/secure/secure,
protocols/secure/plaintext, # for plain text
peerinfo,
protocols/identify,
protocols/pubsub/pubsub,
muxers/muxer,
errors,
peerid
peerid,
errors
logScope:
topics = "switch"
@ -307,12 +304,11 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
conn.peerInfo.close()
finally:
await conn.close()
libp2p_peers.set(s.connections.len.int64)
if lock.locked():
lock.release()
libp2p_peers.set(s.connections.len.int64)
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
let connections = s.connections.getOrDefault(peer.id)
for connHolder in connections:

View File

@ -16,9 +16,6 @@ import transport,
../stream/connection,
../stream/chronosstream
when chronicles.enabledLogLevel == LogLevel.TRACE:
import oids
logScope:
topics = "tcptransport"
@ -74,7 +71,7 @@ proc connHandler*(t: TcpTransport,
proc cleanup() {.async.} =
try:
await client.join()
trace "cleaning up client", addrs = client.remoteAddress, connoid = conn.oid
trace "cleaning up client", addrs = $client.remoteAddress, connoid = conn.oid
if not(isNil(conn)):
await conn.close()
t.clients.keepItIf(it != client)

View File

@ -7,12 +7,11 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import sequtils, tables
import sequtils
import chronos, chronicles
import ../stream/connection,
../multiaddress,
../multicodec,
../errors
../multicodec
type
ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.}

View File

@ -9,7 +9,7 @@
{.used.}
import unittest, sequtils, options, tables, sets
import unittest, sequtils, options, tables
import chronos, stew/byteutils
import utils,
../../libp2p/[errors,
@ -18,8 +18,7 @@ import utils,
crypto/crypto,
protocols/pubsub/pubsub,
protocols/pubsub/floodsub,
protocols/pubsub/rpc/messages,
protocols/pubsub/rpc/message]
protocols/pubsub/rpc/messages]
import ../helpers
@ -29,7 +28,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
var ceil = 15
let fsub = cast[FloodSub](sender.pubSub.get())
while not fsub.floodsub.hasKey(key) or
not fsub.floodsub[key].contains(receiver.peerInfo.id):
not fsub.floodsub.hasPeerID(key, receiver.peerInfo.id):
await sleepAsync(100.millis)
dec ceil
doAssert(ceil > 0, "waitSub timeout!")

View File

@ -29,17 +29,19 @@ suite "GossipSub internal":
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo())
let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
for i in 0..<15:
let conn = newBufferStream(noop)
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].conn = conn
gossipSub.mesh[topic].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.conn = conn
gossipSub.peers[peerInfo.id] = peer
gossipSub.mesh[topic].incl(peer)
check gossipSub.peers.len == 15
await gossipSub.rebalanceMesh(topic)
@ -57,18 +59,20 @@ suite "GossipSub internal":
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo())
let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
gossipSub.topics[topic] = Topic() # has to be in topics to rebalance
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
for i in 0..<15:
let conn = newBufferStream(noop)
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].conn = conn
gossipSub.mesh[topic].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.conn = conn
gossipSub.peers[peerInfo.id] = peer
gossipSub.mesh[topic].incl(peer)
check gossipSub.mesh[topic].len == 15
await gossipSub.rebalanceMesh(topic)
@ -89,7 +93,7 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.gossipsub[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
for i in 0..<15:
@ -97,10 +101,9 @@ suite "GossipSub internal":
conns &= conn
var peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
gossipSub.peers[peerInfo.id].topics &= topic
gossipSub.gossipsub[topic].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
gossipSub.gossipsub[topic].incl(peer)
check gossipSub.gossipsub[topic].len == 15
gossipSub.replenishFanout(topic)
@ -121,7 +124,7 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
gossipSub.lastFanoutPubSub[topic] = Moment.fromNow(1.millis)
await sleepAsync(5.millis) # allow the topic to expire
@ -131,13 +134,13 @@ suite "GossipSub internal":
conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
gossipSub.fanout[topic].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
gossipSub.fanout[topic].incl(peer)
check gossipSub.fanout[topic].len == GossipSubD
await gossipSub.dropFanoutPeers()
gossipSub.dropFanoutPeers()
check topic notin gossipSub.fanout
await allFuturesThrowing(conns.mapIt(it.close()))
@ -156,8 +159,8 @@ suite "GossipSub internal":
let topic1 = "foobar1"
let topic2 = "foobar2"
gossipSub.fanout[topic1] = initHashSet[string]()
gossipSub.fanout[topic2] = initHashSet[string]()
gossipSub.fanout[topic1] = initHashSet[PubSubPeer]()
gossipSub.fanout[topic2] = initHashSet[PubSubPeer]()
gossipSub.lastFanoutPubSub[topic1] = Moment.fromNow(1.millis)
gossipSub.lastFanoutPubSub[topic2] = Moment.fromNow(1.minutes)
await sleepAsync(5.millis) # allow the topic to expire
@ -168,15 +171,15 @@ suite "GossipSub internal":
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
gossipSub.fanout[topic1].incl(peerInfo.id)
gossipSub.fanout[topic2].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
gossipSub.fanout[topic1].incl(peer)
gossipSub.fanout[topic2].incl(peer)
check gossipSub.fanout[topic1].len == GossipSubD
check gossipSub.fanout[topic2].len == GossipSubD
await gossipSub.dropFanoutPeers()
gossipSub.dropFanoutPeers()
check topic1 notin gossipSub.fanout
check topic2 in gossipSub.fanout
@ -195,9 +198,9 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]()
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
# generate mesh and fanout peers
@ -206,12 +209,12 @@ suite "GossipSub internal":
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
if i mod 2 == 0:
gossipSub.fanout[topic].incl(peerInfo.id)
gossipSub.fanout[topic].incl(peer)
else:
gossipSub.mesh[topic].incl(peerInfo.id)
gossipSub.mesh[topic].incl(peer)
# generate gossipsub (free standing) peers
for i in 0..<15:
@ -219,9 +222,9 @@ suite "GossipSub internal":
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
gossipSub.gossipsub[topic].incl(peerInfo.id)
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
gossipSub.gossipsub[topic].incl(peer)
# generate messages
for i in 0..5:
@ -239,8 +242,8 @@ suite "GossipSub internal":
let peers = gossipSub.getGossipPeers()
check peers.len == GossipSubD
for p in peers.keys:
check p notin gossipSub.fanout[topic]
check p notin gossipSub.mesh[topic]
check not gossipSub.fanout.hasPeerID(topic, p)
check not gossipSub.mesh.hasPeerID(topic, p)
await allFuturesThrowing(conns.mapIt(it.close()))
@ -257,20 +260,20 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
for i in 0..<30:
let conn = newBufferStream(noop)
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
if i mod 2 == 0:
gossipSub.fanout[topic].incl(peerInfo.id)
gossipSub.fanout[topic].incl(peer)
else:
gossipSub.gossipsub[topic].incl(peerInfo.id)
gossipSub.gossipsub[topic].incl(peer)
# generate messages
for i in 0..5:
@ -299,20 +302,20 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.gossipsub[topic] = initHashSet[string]()
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
for i in 0..<30:
let conn = newBufferStream(noop)
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
if i mod 2 == 0:
gossipSub.mesh[topic].incl(peerInfo.id)
gossipSub.mesh[topic].incl(peer)
else:
gossipSub.gossipsub[topic].incl(peerInfo.id)
gossipSub.gossipsub[topic].incl(peer)
# generate messages
for i in 0..5:
@ -341,20 +344,20 @@ suite "GossipSub internal":
discard
let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[string]()
gossipSub.fanout[topic] = initHashSet[string]()
gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
gossipSub.fanout[topic] = initHashSet[PubSubPeer]()
var conns = newSeq[Connection]()
for i in 0..<30:
let conn = newBufferStream(noop)
conns &= conn
let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo
gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec_11)
gossipSub.peers[peerInfo.id].handler = handler
let peer = newPubSubPeer(peerInfo, GossipSubCodec)
peer.handler = handler
if i mod 2 == 0:
gossipSub.mesh[topic].incl(peerInfo.id)
gossipSub.mesh[topic].incl(peer)
else:
gossipSub.fanout[topic].incl(peerInfo.id)
gossipSub.fanout[topic].incl(peer)
# generate messages
for i in 0..5:

View File

@ -32,11 +32,11 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
var ceil = 15
let fsub = GossipSub(sender.pubSub.get())
while (not fsub.gossipsub.hasKey(key) or
not fsub.gossipsub[key].contains(receiver.peerInfo.id)) and
not fsub.gossipsub.hasPeerID(key, receiver.peerInfo.id)) and
(not fsub.mesh.hasKey(key) or
not fsub.mesh[key].contains(receiver.peerInfo.id)) and
not fsub.mesh.hasPeerID(key, receiver.peerInfo.id)) and
(not fsub.fanout.hasKey(key) or
not fsub.fanout[key].contains(receiver.peerInfo.id)):
not fsub.fanout.hasPeerID(key , receiver.peerInfo.id)):
trace "waitSub sleeping..."
await sleepAsync(1.seconds)
dec ceil
@ -192,7 +192,7 @@ suite "GossipSub":
check:
"foobar" in gossip2.topics
"foobar" in gossip1.gossipsub
gossip2.peerInfo.id in gossip1.gossipsub["foobar"]
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id)
await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFuturesThrowing(awaitters)
@ -236,11 +236,11 @@ suite "GossipSub":
"foobar" in gossip1.gossipsub
"foobar" in gossip2.gossipsub
gossip2.peerInfo.id in gossip1.gossipsub["foobar"] or
gossip2.peerInfo.id in gossip1.mesh["foobar"]
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) or
gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id)
gossip1.peerInfo.id in gossip2.gossipsub["foobar"] or
gossip1.peerInfo.id in gossip2.mesh["foobar"]
gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.id) or
gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id)
await allFuturesThrowing(nodes.mapIt(it.stop()))
await allFuturesThrowing(awaitters)

677
tests/testminprotobuf.nim Normal file
View File

@ -0,0 +1,677 @@
## Nim-Libp2p
## Copyright (c) 2018 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import unittest
import ../libp2p/protobuf/minprotobuf
import stew/byteutils, strutils
when defined(nimHasUsed): {.used.}
suite "MinProtobuf test suite":
const VarintVectors = [
"0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f",
"08ffffffffffffffffff01"
]
const VarintValues = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64
]
const Fixed32Vectors = [
"0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff"
]
const Fixed32Values = [
0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32
]
const Fixed64Vectors = [
"090000000000000000", "090100000000000000", "09ffffff7f00000000",
"09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f",
"099988ffeeddccbbaa", "09ffffffffffffffff"
]
const Fixed64Values = [
0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64,
0xFFFF_FFFF_FFFF_FFFF'u64
]
const LengthVectors = [
"0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768"
]
const LengthValues = [
"", "a", "ab", "abcd", "abcdefgh"
]
## This vector values was tested with `protoc` and related proto file.
## syntax = "proto2";
## message testmsg {
## repeated uint64 d = 1 [packed=true];
## repeated uint64 d = 2 [packed=true];
## }
const PackedVarintVector =
"0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001"
## syntax = "proto2";
## message testmsg {
## repeated sfixed32 d = 1 [packed=true];
## repeated sfixed32 d = 2 [packed=true];
## }
const PackedFixed32Vector =
"0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000"
## syntax = "proto2";
## message testmsg {
## repeated sfixed64 d = 1 [packed=true];
## repeated sfixed64 d = 2 [packed=true];
## }
const PackedFixed64Vector =
"""0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000
ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210
00000000000000000100000000000000"""
proc getVarintEncodedValue(value: uint64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getVarintDecodedValue(data: openarray[byte]): uint64 =
var value: uint64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
value
proc getFixed32EncodedValue(value: float32): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed32DecodedValue(data: openarray[byte]): uint32 =
var value: float32
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint32](value)
proc getFixed64EncodedValue(value: float64): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getFixed64DecodedValue(data: openarray[byte]): uint64 =
var value: float64
var pb = initProtoBuffer(data)
let res = pb.getField(1, value)
doAssert(res)
cast[uint64](value)
proc getLengthEncodedValue(value: string): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthEncodedValue(value: seq[byte]): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, value)
pb.finish()
return pb.buffer
proc getLengthDecodedValue(data: openarray[byte]): string =
var value = newString(len(data))
var valueLen = 0
var pb = initProtoBuffer(data)
let res = pb.getField(1, value, valueLen)
doAssert(res)
value.setLen(valueLen)
value
proc isFullZero[T: byte|char](data: openarray[T]): bool =
for ch in data:
if int(ch) != 0:
return false
return true
proc corruptHeader(data: var openarray[byte], index: int) =
var values = [3, 4, 6]
data[0] = data[0] and 0xF8'u8
data[0] = data[0] or byte(values[index mod len(values)])
test "[varint] edge values test":
for i in 0 ..< len(VarintValues):
let data = getVarintEncodedValue(VarintValues[i])
check:
toHex(data) == VarintVectors[i]
getVarintDecodedValue(data) == VarintValues[i]
test "[varint] mixing many values with same field number test":
for i in 0 ..< len(VarintValues):
var pb = initProtoBuffer()
for k in 0 ..< len(VarintValues):
let index = (i + k + 1) mod len(VarintValues)
pb.write(1, VarintValues[index])
pb.finish()
check getVarintDecodedValue(pb.buffer) == VarintValues[i]
test "[varint] incorrect values test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] non-existent field test":
for i in 0 ..< len(VarintValues):
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == 0'u64
test "[varint] corrupted header test":
for i in 0 ..< len(VarintValues):
for k in 0 ..< 3:
var value: uint64
var data = getVarintEncodedValue(VarintValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[varint] empty buffer test":
var value: uint64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == 0'u64
test "[varint] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, VarintValues[1])
pb1.write(1, VarintValues[2])
pb1.write(2, VarintValues[3])
pb1.write(1, VarintValues[4])
pb1.write(1, VarintValues[5])
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
fieldarr1[0] == VarintValues[1]
fieldarr1[1] == VarintValues[2]
fieldarr1[2] == VarintValues[4]
fieldarr1[3] == VarintValues[5]
fieldarr2[0] == VarintValues[3]
test "[varint] Repeated packed field test":
var pb1 = initProtoBuffer()
pb1.writePacked(1, VarintValues)
pb1.writePacked(2, VarintValues[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedVarintVector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[uint64]
var fieldarr2: seq[uint64]
var fieldarr3: seq[uint64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 6
fieldarr1[0] == VarintValues[0]
fieldarr1[1] == VarintValues[1]
fieldarr1[2] == VarintValues[2]
fieldarr1[3] == VarintValues[3]
fieldarr1[4] == VarintValues[4]
fieldarr1[5] == VarintValues[5]
fieldarr2[0] == VarintValues[0]
fieldarr2[1] == VarintValues[1]
test "[fixed32] edge values test":
for i in 0 ..< len(Fixed32Values):
let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i]))
check:
toHex(data) == Fixed32Vectors[i]
getFixed32DecodedValue(data) == Fixed32Values[i]
test "[fixed32] mixing many values with same field number test":
for i in 0 ..< len(Fixed32Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed32Values):
let index = (i + k + 1) mod len(Fixed32Values)
pb.write(1, cast[float32](Fixed32Values[index]))
pb.finish()
check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i]
test "[fixed32] incorrect values test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] non-existent field test":
for i in 0 ..< len(Fixed32Values):
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float32(0)
test "[fixed32] corrupted header test":
for i in 0 ..< len(Fixed32Values):
for k in 0 ..< 3:
var value: float32
var data = getFixed32EncodedValue(float32(Fixed32Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed32] empty buffer test":
var value: float32
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float32(0)
test "[fixed32] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float32](Fixed32Values[0]))
pb1.write(1, cast[float32](Fixed32Values[1]))
pb1.write(2, cast[float32](Fixed32Values[2]))
pb1.write(1, cast[float32](Fixed32Values[3]))
pb1.write(1, cast[float32](Fixed32Values[4]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint32](fieldarr1[0]) == Fixed64Values[0]
cast[uint32](fieldarr1[1]) == Fixed64Values[1]
cast[uint32](fieldarr1[2]) == Fixed64Values[3]
cast[uint32](fieldarr1[3]) == Fixed64Values[4]
cast[uint32](fieldarr2[0]) == Fixed64Values[2]
test "[fixed32] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float32](len(Fixed32Values))
for i in 0 ..< len(values):
values[i] = cast[float32](Fixed32Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
check:
toHex(pb1.buffer) == PackedFixed32Vector
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float32]
var fieldarr2: seq[float32]
var fieldarr3: seq[float32]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 5
cast[uint32](fieldarr1[0]) == Fixed32Values[0]
cast[uint32](fieldarr1[1]) == Fixed32Values[1]
cast[uint32](fieldarr1[2]) == Fixed32Values[2]
cast[uint32](fieldarr1[3]) == Fixed32Values[3]
cast[uint32](fieldarr1[4]) == Fixed32Values[4]
cast[uint32](fieldarr2[0]) == Fixed32Values[0]
cast[uint32](fieldarr2[1]) == Fixed32Values[1]
test "[fixed64] edge values test":
for i in 0 ..< len(Fixed64Values):
let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
check:
toHex(data) == Fixed64Vectors[i]
getFixed64DecodedValue(data) == Fixed64Values[i]
test "[fixed64] mixing many values with same field number test":
for i in 0 ..< len(Fixed64Values):
var pb = initProtoBuffer()
for k in 0 ..< len(Fixed64Values):
let index = (i + k + 1) mod len(Fixed64Values)
pb.write(1, cast[float64](Fixed64Values[index]))
pb.finish()
check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i]
test "[fixed64] incorrect values test":
for i in 0 ..< len(Fixed64Values):
var value: float32
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] non-existent field test":
for i in 0 ..< len(Fixed64Values):
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
var pb = initProtoBuffer(data)
check:
pb.getField(2, value) == false
value == float64(0)
test "[fixed64] corrupted header test":
for i in 0 ..< len(Fixed64Values):
for k in 0 ..< 3:
var value: float64
var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i]))
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value) == false
test "[fixed64] empty buffer test":
var value: float64
var pb = initProtoBuffer()
check:
pb.getField(1, value) == false
value == float64(0)
test "[fixed64] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, cast[float64](Fixed64Values[2]))
pb1.write(1, cast[float64](Fixed64Values[3]))
pb1.write(2, cast[float64](Fixed64Values[4]))
pb1.write(1, cast[float64](Fixed64Values[5]))
pb1.write(1, cast[float64](Fixed64Values[6]))
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[uint64](fieldarr1[0]) == Fixed64Values[2]
cast[uint64](fieldarr1[1]) == Fixed64Values[3]
cast[uint64](fieldarr1[2]) == Fixed64Values[5]
cast[uint64](fieldarr1[3]) == Fixed64Values[6]
cast[uint64](fieldarr2[0]) == Fixed64Values[4]
test "[fixed64] Repeated packed field test":
var pb1 = initProtoBuffer()
var values = newSeq[float64](len(Fixed64Values))
for i in 0 ..< len(values):
values[i] = cast[float64](Fixed64Values[i])
pb1.writePacked(1, values)
pb1.writePacked(2, values[0 .. 1])
pb1.finish()
let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", ""))
check:
toHex(pb1.buffer) == expect
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[float64]
var fieldarr2: seq[float64]
var fieldarr3: seq[float64]
let r1 = pb2.getPackedRepeatedField(1, fieldarr1)
let r2 = pb2.getPackedRepeatedField(2, fieldarr2)
let r3 = pb2.getPackedRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 2
len(fieldarr1) == 8
cast[uint64](fieldarr1[0]) == Fixed64Values[0]
cast[uint64](fieldarr1[1]) == Fixed64Values[1]
cast[uint64](fieldarr1[2]) == Fixed64Values[2]
cast[uint64](fieldarr1[3]) == Fixed64Values[3]
cast[uint64](fieldarr1[4]) == Fixed64Values[4]
cast[uint64](fieldarr1[5]) == Fixed64Values[5]
cast[uint64](fieldarr1[6]) == Fixed64Values[6]
cast[uint64](fieldarr1[7]) == Fixed64Values[7]
cast[uint64](fieldarr2[0]) == Fixed64Values[0]
cast[uint64](fieldarr2[1]) == Fixed64Values[1]
test "[length] edge values test":
for i in 0 ..< len(LengthValues):
let data1 = getLengthEncodedValue(LengthValues[i])
let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i]))
check:
toHex(data1) == LengthVectors[i]
toHex(data2) == LengthVectors[i]
check:
getLengthDecodedValue(data1) == LengthValues[i]
getLengthDecodedValue(data2) == LengthValues[i]
test "[length] mixing many values with same field number test":
for i in 0 ..< len(LengthValues):
var pb1 = initProtoBuffer()
var pb2 = initProtoBuffer()
for k in 0 ..< len(LengthValues):
let index = (i + k + 1) mod len(LengthValues)
pb1.write(1, LengthValues[index])
pb2.write(1, cast[seq[byte]](LengthValues[index]))
pb1.finish()
pb2.finish()
check getLengthDecodedValue(pb1.buffer) == LengthValues[i]
check getLengthDecodedValue(pb2.buffer) == LengthValues[i]
test "[length] incorrect values test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
# corrupting
data.setLen(len(data) - 1)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] non-existent field test":
for i in 0 ..< len(LengthValues):
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
var pb = initProtoBuffer(data)
check:
pb.getField(2, value, valueLen) == false
valueLen == 0
test "[length] corrupted header test":
for i in 0 ..< len(LengthValues):
for k in 0 ..< 3:
var value = newSeq[byte](len(LengthValues[i]))
var valueLen = 0
var data = getLengthEncodedValue(LengthValues[i])
data.corruptHeader(k)
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
test "[length] empty buffer test":
var value = newSeq[byte](len(LengthValues[0]))
var valueLen = 0
var pb = initProtoBuffer()
check:
pb.getField(1, value, valueLen) == false
valueLen == 0
test "[length] buffer overflow test":
for i in 1 ..< len(LengthValues):
let data = getLengthEncodedValue(LengthValues[i])
var value = newString(len(LengthValues[i]) - 1)
var valueLen = 0
var pb = initProtoBuffer(data)
check:
pb.getField(1, value, valueLen) == false
valueLen == len(LengthValues[i])
isFullZero(value) == true
test "[length] mix of buffer overflow and normal fields test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST10")
pb1.write(1, "TEST20")
pb1.write(1, "TEST")
pb1.write(1, "TEST30")
pb1.write(1, "SOME")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(4)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == true
value == "SOME"
test "[length] too big message test":
var pb1 = initProtoBuffer()
var bigString = newString(MaxMessageSize + 1)
for i in 0 ..< len(bigString):
bigString[i] = 'A'
pb1.write(1, bigString)
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var value = newString(MaxMessageSize + 1)
var valueLen = 0
check:
pb2.getField(1, value, valueLen) == false
test "[length] Repeated field test":
var pb1 = initProtoBuffer()
pb1.write(1, "TEST1")
pb1.write(1, "TEST2")
pb1.write(2, "TEST5")
pb1.write(1, "TEST3")
pb1.write(1, "TEST4")
pb1.finish()
var pb2 = initProtoBuffer(pb1.buffer)
var fieldarr1: seq[seq[byte]]
var fieldarr2: seq[seq[byte]]
var fieldarr3: seq[seq[byte]]
let r1 = pb2.getRepeatedField(1, fieldarr1)
let r2 = pb2.getRepeatedField(2, fieldarr2)
let r3 = pb2.getRepeatedField(3, fieldarr3)
check:
r1 == true
r2 == true
r3 == false
len(fieldarr3) == 0
len(fieldarr2) == 1
len(fieldarr1) == 4
cast[string](fieldarr1[0]) == "TEST1"
cast[string](fieldarr1[1]) == "TEST2"
cast[string](fieldarr1[2]) == "TEST3"
cast[string](fieldarr1[3]) == "TEST4"
cast[string](fieldarr2[0]) == "TEST5"
test "Different value types in one message with same field number test":
proc getEncodedValue(): seq[byte] =
var pb = initProtoBuffer()
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.write(1, VarintValues[2])
pb.write(2, cast[float32](Fixed32Values[2]))
pb.write(3, cast[float64](Fixed64Values[2]))
pb.write(4, LengthValues[2])
pb.write(1, cast[float32](Fixed32Values[3]))
pb.write(2, cast[float64](Fixed64Values[3]))
pb.write(3, LengthValues[3])
pb.write(4, VarintValues[3])
pb.write(1, cast[float64](Fixed64Values[4]))
pb.write(2, LengthValues[4])
pb.write(3, VarintValues[4])
pb.write(4, cast[float32](Fixed32Values[4]))
pb.write(1, VarintValues[1])
pb.write(2, cast[float32](Fixed32Values[1]))
pb.write(3, cast[float64](Fixed64Values[1]))
pb.write(4, LengthValues[1])
pb.finish()
pb.buffer
let msg = getEncodedValue()
let pb = initProtoBuffer(msg)
var varintValue: uint64
var fixed32Value: float32
var fixed64Value: float64
var lengthValue = newString(10)
var lengthSize: int
check:
pb.getField(1, varintValue) == true
pb.getField(2, fixed32Value) == true
pb.getField(3, fixed64Value) == true
pb.getField(4, lengthValue, lengthSize) == true
lengthValue.setLen(lengthSize)
check:
varintValue == VarintValues[1]
cast[uint32](fixed32Value) == Fixed32Values[1]
cast[uint64](fixed64Value) == Fixed64Values[1]
lengthValue == LengthValues[1]

View File

@ -1,4 +1,4 @@
import unittest, strutils, sequtils, strformat, stew/byteutils
import unittest, strutils, strformat, stew/byteutils
import chronos
import ../libp2p/errors,
../libp2p/multistream,

View File

@ -1,4 +1,5 @@
import testvarint,
testminprotobuf,
teststreamseq
import testrsa,

View File

@ -9,19 +9,13 @@ import ../libp2p/[errors,
multistream,
standard_setup,
stream/bufferstream,
protocols/identify,
stream/connection,
transports/transport,
transports/tcptransport,
multiaddress,
peerinfo,
crypto/crypto,
protocols/protocol,
muxers/muxer,
muxers/mplex/mplex,
muxers/mplex/types,
protocols/secure/secio,
protocols/secure/secure,
stream/lpstream]
import ./helpers