commit 90a33c0a29
Author: Giovanni Petrantoni
Date:   2020-08-12 10:57:32 +09:00

    Merge branch 'master' into gossip-one-one

31 changed files with 1529 additions and 1252 deletions

View File

@@ -121,7 +121,7 @@ proc onClose(c: ConnManager, conn: Connection) {.async.} =
   ## triggers the connections resource cleanup
   ##
-  await conn.closeEvent.wait()
+  await conn.join()
   trace "triggering connection cleanup"
   await c.cleanupConn(conn)
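For context, a minimal sketch of the join-based cleanup pattern this hunk switches to (the watcher proc and releaseResources callback are illustrative, not part of the library):

proc watchConnection(conn: Connection) {.async.} =
  # join() yields a future that completes once the connection has closed,
  # so any per-connection resources can be released safely after this point.
  await conn.join()
  releaseResources(conn)   # hypothetical caller-owned cleanup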

View File

@@ -70,13 +70,15 @@ when supported(PKScheme.Secp256k1):
 import ecnist, bearssl
 import ../protobuf/minprotobuf, ../vbuffer, ../multihash, ../multicodec
-import nimcrypto/[rijndael, twofish, sha2, hash, hmac, utils]
+import nimcrypto/[rijndael, twofish, sha2, hash, hmac]
+# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
+import nimcrypto/utils as ncrutils
 import ../utility
 import stew/results
 export results
 # This is workaround for Nim's `import` bug
-export rijndael, twofish, sha2, hash, hmac, utils
+export rijndael, twofish, sha2, hash, hmac, ncrutils
 from strutils import split

@@ -514,20 +516,14 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: string): bool =
   ## hexadecimal string representation.
   ##
   ## Returns ``true`` on success.
-  try:
-    key.init(utils.fromHex(data))
-  except ValueError:
-    false
+  key.init(ncrutils.fromHex(data))

 proc init*(sig: var Signature, data: string): bool =
   ## Initialize signature ``sig`` from serialized hexadecimal string
   ## representation.
   ##
   ## Returns ``true`` on success.
-  try:
-    sig.init(utils.fromHex(data))
-  except ValueError:
-    false
+  sig.init(ncrutils.fromHex(data))

 proc init*(t: typedesc[PrivateKey],
            data: openarray[byte]): CryptoResult[PrivateKey] =

@@ -559,10 +555,7 @@ proc init*(t: typedesc[Signature],
 proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] =
   ## Create new private key from libp2p's protobuf serialized hexadecimal string
   ## form.
-  try:
-    t.init(utils.fromHex(data))
-  except ValueError:
-    err(KeyError)
+  t.init(ncrutils.fromHex(data))

 when supported(PKScheme.RSA):
   proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey =

@@ -591,17 +584,11 @@ when supported(PKScheme.ECDSA):
 proc init*(t: typedesc[PublicKey], data: string): CryptoResult[PublicKey] =
   ## Create new public key from libp2p's protobuf serialized hexadecimal string
   ## form.
-  try:
-    t.init(utils.fromHex(data))
-  except ValueError:
-    err(KeyError)
+  t.init(ncrutils.fromHex(data))

 proc init*(t: typedesc[Signature], data: string): CryptoResult[Signature] =
   ## Create new signature from serialized hexadecimal string form.
-  try:
-    t.init(utils.fromHex(data))
-  except ValueError:
-    err(SigError)
+  t.init(ncrutils.fromHex(data))

 proc `==`*(key1, key2: PublicKey): bool {.inline.} =
   ## Return ``true`` if two public keys ``key1`` and ``key2`` of the same

@@ -709,7 +696,7 @@ func shortLog*(key: PrivateKey|PublicKey): string =
 proc `$`*(sig: Signature): string =
   ## Get string representation of signature ``sig``.
-  result = toHex(sig.data)
+  result = ncrutils.toHex(sig.data)

 proc sign*(key: PrivateKey,
            data: openarray[byte]): CryptoResult[Signature] {.gcsafe.} =
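With the explicit try/except around `fromHex` gone, callers interact only with the Result types; a minimal usage sketch (the hex literal is a placeholder and the crypto module import is assumed):

# import libp2p/crypto/crypto  -- module path assumed for this sketch
let keyRes = PrivateKey.init("080112...")        # placeholder hex, not a real key
if keyRes.isOk():
  echo "decoded key: ", shortLog(keyRes.get())
else:
  echo "malformed key material"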

View File

@@ -17,7 +17,8 @@
 {.push raises: [Defect].}

 import bearssl
-import nimcrypto/utils
+# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
+import nimcrypto/utils as ncrutils
 import minasn1
 export minasn1.Asn1Error
 import stew/[results, ctops]

@@ -289,7 +290,7 @@ proc `$`*(seckey: EcPrivateKey): string =
     result = "Corrupted key"
   else:
     let e = offset + cast[int](seckey.key.xlen) - 1
-    result = toHex(seckey.buffer.toOpenArray(offset, e))
+    result = ncrutils.toHex(seckey.buffer.toOpenArray(offset, e))

 proc `$`*(pubkey: EcPublicKey): string =
   ## Return string representation of EC public key.

@@ -305,14 +306,14 @@ proc `$`*(pubkey: EcPublicKey): string =
     result = "Corrupted key"
   else:
     let e = offset + cast[int](pubkey.key.qlen) - 1
-    result = toHex(pubkey.buffer.toOpenArray(offset, e))
+    result = ncrutils.toHex(pubkey.buffer.toOpenArray(offset, e))

 proc `$`*(sig: EcSignature): string =
   ## Return hexadecimal string representation of EC signature.
   if isNil(sig) or len(sig.buffer) == 0:
     result = "Empty or uninitialized ECNIST signature"
   else:
-    result = toHex(sig.buffer)
+    result = ncrutils.toHex(sig.buffer)

 proc toRawBytes*(seckey: EcPrivateKey, data: var openarray[byte]): EcResult[int] =
   ## Serialize EC private key ``seckey`` to raw binary form and store it

@@ -708,14 +709,16 @@ proc init*(sig: var EcSignature, data: openarray[byte]): Result[void, Asn1Error]
   else:
     err(Asn1Error.Incorrect)

-proc init*[T: EcPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} =
+proc init*[T: EcPKI](sospk: var T,
+                     data: string): Result[void, Asn1Error] {.inline.} =
   ## Initialize EC `private key`, `public key` or `signature` ``sospk`` from
   ## ASN.1 DER hexadecimal string representation ``data``.
   ##
   ## Procedure returns ``Asn1Status``.
-  sospk.init(fromHex(data))
+  sospk.init(ncrutils.fromHex(data))

-proc init*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivateKey] =
+proc init*(t: typedesc[EcPrivateKey],
+           data: openarray[byte]): EcResult[EcPrivateKey] =
   ## Initialize EC private key from ASN.1 DER binary representation ``data`` and
   ## return constructed object.
   var key: EcPrivateKey

@@ -725,7 +728,8 @@ proc init*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivate
   else:
     ok(key)

-proc init*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKey] =
+proc init*(t: typedesc[EcPublicKey],
+           data: openarray[byte]): EcResult[EcPublicKey] =
   ## Initialize EC public key from ASN.1 DER binary representation ``data`` and
   ## return constructed object.
   var key: EcPublicKey

@@ -735,7 +739,8 @@ proc init*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKe
   else:
     ok(key)

-proc init*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignature] =
+proc init*(t: typedesc[EcSignature],
+           data: openarray[byte]): EcResult[EcSignature] =
   ## Initialize EC signature from raw binary representation ``data`` and
   ## return constructed object.
   var sig: EcSignature

@@ -748,10 +753,7 @@ proc init*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignatur
 proc init*[T: EcPKI](t: typedesc[T], data: string): EcResult[T] =
   ## Initialize EC `private key`, `public key` or `signature` from hexadecimal
   ## string representation ``data`` and return constructed object.
-  try:
-    t.init(fromHex(data))
-  except ValueError:
-    err(EcKeyIncorrectError)
+  t.init(ncrutils.fromHex(data))

 proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool =
   ## Initialize EC `private key` or `scalar` ``key`` from raw binary

@@ -833,9 +835,10 @@ proc initRaw*[T: EcPKI](sospk: var T, data: string): bool {.inline.} =
   ## raw hexadecimal string representation ``data``.
   ##
   ## Procedure returns ``true`` on success, ``false`` otherwise.
-  result = sospk.initRaw(fromHex(data))
+  result = sospk.initRaw(ncrutils.fromHex(data))

-proc initRaw*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivateKey] =
+proc initRaw*(t: typedesc[EcPrivateKey],
+              data: openarray[byte]): EcResult[EcPrivateKey] =
   ## Initialize EC private key from raw binary representation ``data`` and
   ## return constructed object.
   var res: EcPrivateKey

@@ -844,7 +847,8 @@ proc initRaw*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPriv
   else:
     ok(res)

-proc initRaw*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKey] =
+proc initRaw*(t: typedesc[EcPublicKey],
+              data: openarray[byte]): EcResult[EcPublicKey] =
   ## Initialize EC public key from raw binary representation ``data`` and
   ## return constructed object.
   var res: EcPublicKey

@@ -853,7 +857,8 @@ proc initRaw*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPubli
   else:
     ok(res)

-proc initRaw*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignature] =
+proc initRaw*(t: typedesc[EcSignature],
+              data: openarray[byte]): EcResult[EcSignature] =
   ## Initialize EC signature from raw binary representation ``data`` and
   ## return constructed object.
   var res: EcSignature

@@ -865,7 +870,7 @@ proc initRaw*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSigna
 proc initRaw*[T: EcPKI](t: typedesc[T], data: string): T {.inline.} =
   ## Initialize EC `private key`, `public key` or `signature` from raw
   ## hexadecimal string representation ``data`` and return constructed object.
-  result = t.initRaw(fromHex(data))
+  result = t.initRaw(ncrutils.fromHex(data))

 proc scalarMul*(pub: EcPublicKey, sec: EcPrivateKey): EcPublicKey =
   ## Return scalar multiplication of ``pub`` and ``sec``.

@@ -926,7 +931,7 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] =
     copyMem(addr result[0], addr data[0], res)

 proc sign*[T: byte|char](seckey: EcPrivateKey,
                          message: openarray[T]): EcResult[EcSignature] {.gcsafe.} =
   ## Get ECDSA signature of data ``message`` using private key ``seckey``.
   if isNil(seckey):
     return err(EcKeyIncorrectError)

View File

@@ -14,7 +14,9 @@
 {.push raises: Defect.}

 import constants, bearssl
-import nimcrypto/[hash, sha2, utils]
+import nimcrypto/[hash, sha2]
+# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
+import nimcrypto/utils as ncrutils
 import stew/[results, ctops]
 export results

@@ -1735,14 +1737,17 @@ proc `==`*(eda, edb: EdSignature): bool =
   ## Compare ED25519 `signature` objects for equality.
   result = CT.isEqual(eda.data, edb.data)

-proc `$`*(key: EdPrivateKey): string = toHex(key.data)
+proc `$`*(key: EdPrivateKey): string =
   ## Return string representation of ED25519 `private key`.
+  ncrutils.toHex(key.data)

-proc `$`*(key: EdPublicKey): string = toHex(key.data)
+proc `$`*(key: EdPublicKey): string =
   ## Return string representation of ED25519 `private key`.
+  ncrutils.toHex(key.data)

-proc `$`*(sig: EdSignature): string = toHex(sig.data)
+proc `$`*(sig: EdSignature): string =
   ## Return string representation of ED25519 `signature`.
+  ncrutils.toHex(sig.data)

 proc init*(key: var EdPrivateKey, data: openarray[byte]): bool =
   ## Initialize ED25519 `private key` ``key`` from raw binary

@@ -1779,32 +1784,24 @@ proc init*(key: var EdPrivateKey, data: string): bool =
   ## representation ``data``.
   ##
   ## Procedure returns ``true`` on success.
-  try:
-    init(key, fromHex(data))
-  except ValueError:
-    false
+  init(key, ncrutils.fromHex(data))

 proc init*(key: var EdPublicKey, data: string): bool =
   ## Initialize ED25519 `public key` ``key`` from hexadecimal string
   ## representation ``data``.
   ##
   ## Procedure returns ``true`` on success.
-  try:
-    init(key, fromHex(data))
-  except ValueError:
-    false
+  init(key, ncrutils.fromHex(data))

 proc init*(sig: var EdSignature, data: string): bool =
   ## Initialize ED25519 `signature` ``sig`` from hexadecimal string
   ## representation ``data``.
   ##
   ## Procedure returns ``true`` on success.
-  try:
-    init(sig, fromHex(data))
-  except ValueError:
-    false
+  init(sig, ncrutils.fromHex(data))

-proc init*(t: typedesc[EdPrivateKey], data: openarray[byte]): Result[EdPrivateKey, EdError] =
+proc init*(t: typedesc[EdPrivateKey],
+           data: openarray[byte]): Result[EdPrivateKey, EdError] =
   ## Initialize ED25519 `private key` from raw binary representation ``data``
   ## and return constructed object.
   var res: t

@@ -1813,7 +1810,8 @@ proc init*(t: typedesc[EdPrivateKey], data: openarray[byte]): Result[EdPrivateKe
   else:
     ok(res)

-proc init*(t: typedesc[EdPublicKey], data: openarray[byte]): Result[EdPublicKey, EdError] =
+proc init*(t: typedesc[EdPublicKey],
+           data: openarray[byte]): Result[EdPublicKey, EdError] =
   ## Initialize ED25519 `public key` from raw binary representation ``data``
   ## and return constructed object.
   var res: t

@@ -1822,7 +1820,8 @@ proc init*(t: typedesc[EdPublicKey], data: openarray[byte]): Result[EdPublicKey,
   else:
     ok(res)

-proc init*(t: typedesc[EdSignature], data: openarray[byte]): Result[EdSignature, EdError] =
+proc init*(t: typedesc[EdSignature],
+           data: openarray[byte]): Result[EdSignature, EdError] =
   ## Initialize ED25519 `signature` from raw binary representation ``data``
   ## and return constructed object.
   var res: t

@@ -1831,7 +1830,8 @@ proc init*(t: typedesc[EdSignature], data: openarray[byte]): Result[EdSignature,
   else:
     ok(res)

-proc init*(t: typedesc[EdPrivateKey], data: string): Result[EdPrivateKey, EdError] =
+proc init*(t: typedesc[EdPrivateKey],
+           data: string): Result[EdPrivateKey, EdError] =
   ## Initialize ED25519 `private key` from hexadecimal string representation
   ## ``data`` and return constructed object.
   var res: t

@@ -1840,7 +1840,8 @@ proc init*(t: typedesc[EdPrivateKey], data: string): Result[EdPrivateKey, EdErro
   else:
     ok(res)

-proc init*(t: typedesc[EdPublicKey], data: string): Result[EdPublicKey, EdError] =
+proc init*(t: typedesc[EdPublicKey],
+           data: string): Result[EdPublicKey, EdError] =
   ## Initialize ED25519 `public key` from hexadecimal string representation
   ## ``data`` and return constructed object.
   var res: t

@@ -1849,7 +1850,8 @@ proc init*(t: typedesc[EdPublicKey], data: string): Result[EdPublicKey, EdError]
   else:
     ok(res)

-proc init*(t: typedesc[EdSignature], data: string): Result[EdSignature, EdError] =
+proc init*(t: typedesc[EdSignature],
+           data: string): Result[EdSignature, EdError] =
   ## Initialize ED25519 `signature` from hexadecimal string representation
   ## ``data`` and return constructed object.
   var res: t

View File

@@ -11,9 +11,10 @@
 {.push raises: [Defect].}

-import stew/[endians2, results]
+import stew/[endians2, results, ctops]
 export results
-import nimcrypto/utils
+# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
+import nimcrypto/utils as ncrutils

 type
   Asn1Error* {.pure.} = enum

@@ -122,7 +123,7 @@ proc len*[T: Asn1Buffer|Asn1Composite](abc: T): int {.inline.} =
   len(abc.buffer) - abc.offset

 proc len*(field: Asn1Field): int {.inline.} =
-  result = field.length
+  field.length

 template getPtr*(field: untyped): pointer =
   cast[pointer](unsafeAddr field.buffer[field.offset])

@@ -153,30 +154,32 @@ proc code*(tag: Asn1Tag): byte {.inline.} =
   of Asn1Tag.Context:
     0xA0'u8

-proc asn1EncodeLength*(dest: var openarray[byte], length: int64): int =
+proc asn1EncodeLength*(dest: var openarray[byte], length: uint64): int =
   ## Encode ASN.1 DER length part of TLV triple and return number of bytes
   ## (octets) used.
   ##
   ## If length of ``dest`` is less then number of required bytes to encode
-  ## ``length`` value, then result of encoding will not be stored in ``dest``
+  ## ``length`` value, then result of encoding WILL NOT BE stored in ``dest``
   ## but number of bytes (octets) required will be returned.
-  if length < 0x80:
+  if length < 0x80'u64:
     if len(dest) >= 1:
-      dest[0] = cast[byte](length)
-    result = 1
+      dest[0] = byte(length and 0x7F'u64)
+    1
   else:
-    result = 0
+    var res = 1'u64
     var z = length
     while z != 0:
-      inc(result)
+      inc(res)
       z = z shr 8
-    if len(dest) >= result + 1:
-      dest[0] = cast[byte](0x80 + result)
+    if uint64(len(dest)) >= res:
+      dest[0] = byte((0x80'u64 + (res - 1'u64)) and 0xFF)
       var o = 1
-      for j in countdown(result - 1, 0):
-        dest[o] = cast[byte](length shr (j shl 3))
+      for j in countdown(res - 2, 0):
+        dest[o] = byte((length shr (j shl 3)) and 0xFF'u64)
         inc(o)
-    inc(result)
+    # Because our `length` argument is `uint64`, `res` could not be bigger
+    # then 9, so it is safe to convert it to `int`.
+    int(res)

 proc asn1EncodeInteger*(dest: var openarray[byte],
                         value: openarray[byte]): int =
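The switch to a `uint64` length parameter keeps the established two-pass contract: call the encoder with a too-small (or empty) destination to learn the required size, then call it again with a buffer of that size. A minimal sketch, assuming the minasn1 module above is imported:

var probe: array[0, byte]
let needed = asn1EncodeLength(probe, 300'u64)   # destination too small: only reports the size (3)
var dst = newSeq[byte](needed)
discard asn1EncodeLength(dst, 300'u64)          # dst now holds 0x82 0x01 0x2C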
@@ -184,35 +187,46 @@ proc asn1EncodeInteger*(dest: var openarray[byte],
   ## and return number of bytes (octets) used.
   ##
   ## If length of ``dest`` is less then number of required bytes to encode
-  ## ``value``, then result of encoding will not be stored in ``dest``
+  ## ``value``, then result of encoding WILL NOT BE stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var o = 0
   var lenlen = 0
-  for i in 0..<len(value):
-    if value[o] != 0x00:
-      break
-    inc(o)
-  if len(value) > 0:
-    if o == len(value):
-      dec(o)
-    if value[o] >= 0x80'u8:
-      lenlen = asn1EncodeLength(buffer, len(value) - o + 1)
-      result = 1 + lenlen + 1 + (len(value) - o)
-    else:
-      lenlen = asn1EncodeLength(buffer, len(value) - o)
-      result = 1 + lenlen + (len(value) - o)
-  else:
-    result = 2
-  if len(dest) >= result:
-    var s = 1
+
+  let offset =
+    block:
+      var o = 0
+      for i in 0 ..< len(value):
+        if value[o] != 0x00:
+          break
+        inc(o)
+      if o < len(value):
+        o
+      else:
+        o - 1
+
+  let destlen =
+    if len(value) > 0:
+      if value[offset] >= 0x80'u8:
+        lenlen = asn1EncodeLength(buffer, uint64(len(value) - offset + 1))
+        1 + lenlen + 1 + (len(value) - offset)
+      else:
+        lenlen = asn1EncodeLength(buffer, uint64(len(value) - offset))
+        1 + lenlen + (len(value) - offset)
+    else:
+      2
+
+  if len(dest) >= destlen:
+    var shift = 1
     dest[0] = Asn1Tag.Integer.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
-    if value[o] >= 0x80'u8:
-      dest[1 + lenlen] = 0x00'u8
-      s = 2
-    if len(value) > 0:
-      copyMem(addr dest[s + lenlen], unsafeAddr value[o], len(value) - o)
+    # If ``destlen > 2`` it means that ``len(value) > 0`` too.
+    if destlen > 2:
+      if value[offset] >= 0x80'u8:
+        dest[1 + lenlen] = 0x00'u8
+        shift = 2
+      copyMem(addr dest[shift + lenlen], unsafeAddr value[offset],
+              len(value) - offset)
+  destlen

@@ -231,11 +245,12 @@ proc asn1EncodeBoolean*(dest: var openarray[byte], value: bool): int =
   ## If length of ``dest`` is less then number of required bytes to encode
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
-  result = 3
-  if len(dest) >= result:
+  let res = 3
+  if len(dest) >= res:
     dest[0] = Asn1Tag.Boolean.code()
     dest[1] = 0x01'u8
     dest[2] = if value: 0xFF'u8 else: 0x00'u8
+  res

 proc asn1EncodeNull*(dest: var openarray[byte]): int =
   ## Encode ASN.1 DER `NULL` and return number of bytes (octets) used.

@@ -243,13 +258,14 @@ proc asn1EncodeNull*(dest: var openarray[byte]): int =
   ## If length of ``dest`` is less then number of required bytes to encode
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
-  result = 2
-  if len(dest) >= result:
+  let res = 2
+  if len(dest) >= res:
     dest[0] = Asn1Tag.Null.code()
     dest[1] = 0x00'u8
+  res

 proc asn1EncodeOctetString*(dest: var openarray[byte],
                             value: openarray[byte]): int =
   ## Encode array of bytes as ASN.1 DER `OCTET STRING` and return number of
   ## bytes (octets) used.
   ##

@@ -257,38 +273,50 @@ proc asn1EncodeOctetString*(dest: var openarray[byte],
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value))
-  result = 1 + lenlen + len(value)
-  if len(dest) >= result:
+  let lenlen = asn1EncodeLength(buffer, uint64(len(value)))
+  let res = 1 + lenlen + len(value)
+  if len(dest) >= res:
     dest[0] = Asn1Tag.OctetString.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
     if len(value) > 0:
       copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
+  res

 proc asn1EncodeBitString*(dest: var openarray[byte],
                           value: openarray[byte], bits = 0): int =
   ## Encode array of bytes as ASN.1 DER `BIT STRING` and return number of bytes
   ## (octets) used.
   ##
-  ## ``bits`` number of used bits in ``value``. If ``bits == 0``, all the bits
-  ## from ``value`` are used, if ``bits != 0`` only number of ``bits`` will be
-  ## used.
+  ## ``bits`` number of unused bits in ``value``. If ``bits == 0``, all the bits
+  ## from ``value`` will be used.
   ##
   ## If length of ``dest`` is less then number of required bytes to encode
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value) + 1)
-  var lbits = 0
-  if bits != 0:
-    lbits = len(value) shl 3 - bits
-  result = 1 + lenlen + 1 + len(value)
-  if len(dest) >= result:
+  let bitlen =
+    if bits != 0:
+      (len(value) shl 3) - bits
+    else:
+      (len(value) shl 3)
+
+  # Number of bytes used
+  let bytelen = (bitlen + 7) shr 3
+  # Number of unused bits
+  let unused = (8 - (bitlen and 7)) and 7
+  let mask = not((1'u8 shl unused) - 1'u8)
+  var lenlen = asn1EncodeLength(buffer, uint64(bytelen + 1))
+  let res = 1 + lenlen + 1 + len(value)
+  if len(dest) >= res:
     dest[0] = Asn1Tag.BitString.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
-    dest[1 + lenlen] = cast[byte](lbits)
-    if len(value) > 0:
-      copyMem(addr dest[2 + lenlen], unsafeAddr value[0], len(value))
+    dest[1 + lenlen] = byte(unused)
+    if bytelen > 0:
+      let lastbyte = value[bytelen - 1]
+      copyMem(addr dest[2 + lenlen], unsafeAddr value[0], bytelen)
+      # Set unused bits to zero
+      dest[2 + lenlen + bytelen - 1] = lastbyte and mask
+  res

 proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openarray[byte],
                                        value: T): int =
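Note that the meaning of ``bits`` changes here from used bits to unused bits, and the encoder now zeroes the unused low bits itself. A small sketch against the updated proc (minasn1 import assumed):

# One octet carrying a 6-bit value, so 2 bits of the final octet are unused.
var dst = newSeq[byte](8)
let used = asn1EncodeBitString(dst, [0xD8'u8], bits = 2)
# used == 4 and dst[0 ..< used] == @[0x03'u8, 0x02, 0x02, 0xD8]
# (tag, length 2, "2 unused bits", data with the unused bits already zero).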
@@ -296,53 +324,48 @@ proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openarray[byte],
   if value <= cast[T](0x7F):
     if len(dest) >= 1:
       dest[0] = cast[byte](value)
-    result = 1
+    1
   else:
     var s = 0
     var v = value
+    var res = 0
     while v != 0:
       v = v shr 7
       s += 7
-      inc(result)
-    if len(dest) >= result:
+      inc(res)
+    if len(dest) >= res:
       var k = 0
       while s != 0:
         s -= 7
         dest[k] = cast[byte](((value shr s) and cast[T](0x7F)) or cast[T](0x80))
         inc(k)
       dest[k - 1] = dest[k - 1] and 0x7F'u8
+    res

 proc asn1EncodeOid*(dest: var openarray[byte], value: openarray[int]): int =
   ## Encode array of integers ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and
   ## return number of bytes (octets) used.
   ##
+  ## OBJECT IDENTIFIER requirements for ``value`` elements:
+  ## * len(value) >= 2
+  ## * value[0] >= 1 and value[0] < 2
+  ## * value[1] >= 1 and value[1] < 39
+  ##
   ## If length of ``dest`` is less then number of required bytes to encode
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  result = 1
+  var res = 1
+  doAssert(len(value) >= 2)
+  doAssert(value[0] >= 1 and value[0] < 2)
+  doAssert(value[1] >= 1 and value[1] <= 39)
   var oidlen = 1
   for i in 2..<len(value):
     oidlen += asn1EncodeTag(buffer, cast[uint64](value[i]))
-  result += asn1EncodeLength(buffer, oidlen)
-  result += oidlen
-  if len(dest) >= result:
+  res += asn1EncodeLength(buffer, uint64(oidlen))
+  res += oidlen
+  if len(dest) >= res:
     let last = dest.high
     var offset = 1
     dest[0] = Asn1Tag.Oid.code()
-    offset += asn1EncodeLength(dest.toOpenArray(offset, last), oidlen)
+    offset += asn1EncodeLength(dest.toOpenArray(offset, last), uint64(oidlen))
     dest[offset] = cast[byte](value[0] * 40 + value[1])
     offset += 1
     for i in 2..<len(value):
       offset += asn1EncodeTag(dest.toOpenArray(offset, last),
                               cast[uint64](value[i]))
+  res

 proc asn1EncodeOid*(dest: var openarray[byte], value: openarray[byte]): int =
   ## Encode array of bytes ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and return
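The new doAsserts make the first-two-arc requirements explicit (at least two arcs, first arc 1, second arc within the accepted 1..39 range). A sketch encoding the prime256v1 OID 1.2.840.10045.3.1.7 with the same measure-then-encode pattern (minasn1 import assumed):

let oid = [1, 2, 840, 10045, 3, 1, 7]
var probe: array[0, byte]
let size = asn1EncodeOid(probe, oid)   # measure pass, returns 10
var der = newSeq[byte](size)
discard asn1EncodeOid(der, oid)        # der == 06 08 2A 86 48 CE 3D 03 01 07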
@@ -355,12 +378,13 @@ proc asn1EncodeOid*(dest: var openarray[byte], value: openarray[byte]): int =
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value))
-  result = 1 + lenlen + len(value)
-  if len(dest) >= result:
+  let lenlen = asn1EncodeLength(buffer, uint64(len(value)))
+  let res = 1 + lenlen + len(value)
+  if len(dest) >= res:
     dest[0] = Asn1Tag.Oid.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
     copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
+  res

 proc asn1EncodeSequence*(dest: var openarray[byte],
                          value: openarray[byte]): int =

@@ -371,12 +395,13 @@ proc asn1EncodeSequence*(dest: var openarray[byte],
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value))
-  result = 1 + lenlen + len(value)
-  if len(dest) >= result:
+  let lenlen = asn1EncodeLength(buffer, uint64(len(value)))
+  let res = 1 + lenlen + len(value)
+  if len(dest) >= res:
     dest[0] = Asn1Tag.Sequence.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
     copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
+  res

 proc asn1EncodeComposite*(dest: var openarray[byte],
                           value: Asn1Composite): int =

@@ -386,29 +411,34 @@ proc asn1EncodeComposite*(dest: var openarray[byte],
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value.buffer))
-  result = 1 + lenlen + len(value.buffer)
-  if len(dest) >= result:
+  let lenlen = asn1EncodeLength(buffer, uint64(len(value.buffer)))
+  let res = 1 + lenlen + len(value.buffer)
+  if len(dest) >= res:
     dest[0] = value.tag.code()
     copyMem(addr dest[1], addr buffer[0], lenlen)
     copyMem(addr dest[1 + lenlen], unsafeAddr value.buffer[0],
             len(value.buffer))
+  res

 proc asn1EncodeContextTag*(dest: var openarray[byte], value: openarray[byte],
                            tag: int): int =
   ## Encode ASN.1 DER `CONTEXT SPECIFIC TAG` ``tag`` for value ``value`` and
   ## return number of bytes (octets) used.
   ##
+  ## Note: Only values in [0, 15] range can be used as context tag ``tag``
+  ## values.
+  ##
   ## If length of ``dest`` is less then number of required bytes to encode
   ## ``value``, then result of encoding will not be stored in ``dest``
   ## but number of bytes (octets) required will be returned.
   var buffer: array[16, byte]
-  var lenlen = asn1EncodeLength(buffer, len(value))
-  result = 1 + lenlen + len(value)
-  if len(dest) >= result:
-    dest[0] = 0xA0'u8 or (cast[byte](tag) and 0x0F)
+  let lenlen = asn1EncodeLength(buffer, uint64(len(value)))
+  let res = 1 + lenlen + len(value)
+  if len(dest) >= res:
+    dest[0] = 0xA0'u8 or (byte(tag and 0xFF) and 0x0F'u8)
     copyMem(addr dest[1], addr buffer[0], lenlen)
     copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value))
+  res

 proc getLength(ab: var Asn1Buffer): Asn1Result[uint64] =
   ## Decode length part of ASN.1 TLV triplet.
@@ -457,197 +487,300 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] =
     field: Asn1Field
     tag, ttag, offset: int
     length, tlength: uint64
-    klass: Asn1Class
+    aclass: Asn1Class
    inclass: bool

   inclass = false
   while true:
     offset = ab.offset
-    klass = ? ab.getTag(tag)
-    if klass == Asn1Class.ContextSpecific:
+    aclass = ? ab.getTag(tag)
+    case aclass
+    of Asn1Class.ContextSpecific:
       if inclass:
         return err(Asn1Error.Incorrect)
+      else:
         inclass = true
         ttag = tag
         tlength = ? ab.getLength()
-    elif klass == Asn1Class.Universal:
+    of Asn1Class.Universal:
       length = ? ab.getLength()
       if inclass:
         if length >= tlength:
           return err(Asn1Error.Incorrect)
-      if cast[byte](tag) == Asn1Tag.Boolean.code():
+      case byte(tag)
+      of Asn1Tag.Boolean.code():
         # BOOLEAN
         if length != 1:
           return err(Asn1Error.Incorrect)
-        if not ab.isEnough(cast[int](length)):
+        if not ab.isEnough(int(length)):
           return err(Asn1Error.Incomplete)
         let b = ab.buffer[ab.offset]
         if b != 0xFF'u8 and b != 0x00'u8:
           return err(Asn1Error.Incorrect)
-        field = Asn1Field(kind: Asn1Tag.Boolean, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
+        field = Asn1Field(kind: Asn1Tag.Boolean, klass: aclass,
+                          index: ttag, offset: int(ab.offset),
                           length: 1)
         shallowCopy(field.buffer, ab.buffer)
         field.vbool = (b == 0xFF'u8)
         ab.offset += 1
         return ok(field)
-      elif cast[byte](tag) == Asn1Tag.Integer.code():
+      of Asn1Tag.Integer.code():
         # INTEGER
-        if not ab.isEnough(cast[int](length)):
-          return err(Asn1Error.Incomplete)
-        if ab.buffer[ab.offset] == 0x00'u8:
-          length -= 1
-          ab.offset += 1
-        field = Asn1Field(kind: Asn1Tag.Integer, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
-                          length: cast[int](length))
-        shallowCopy(field.buffer, ab.buffer)
-        if length <= 8:
-          for i in 0..<int(length):
-            field.vint = (field.vint shl 8) or
-                         cast[uint64](ab.buffer[ab.offset + i])
-        ab.offset += cast[int](length)
-        return ok(field)
-      elif cast[byte](tag) == Asn1Tag.BitString.code():
+        if length == 0:
+          return err(Asn1Error.Incorrect)
+
+        if not ab.isEnough(int(length)):
+          return err(Asn1Error.Incomplete)
+
+        # Count number of leading zeroes
+        var zc = 0
+        while (zc < int(length)) and (ab.buffer[ab.offset + zc] == 0x00'u8):
+          inc(zc)
+
+        if zc > 1:
+          return err(Asn1Error.Incorrect)
+
+        if zc == 0:
+          # Negative or Positive integer
+          field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
+                            index: ttag, offset: int(ab.offset),
+                            length: int(length))
+          shallowCopy(field.buffer, ab.buffer)
+          if (ab.buffer[ab.offset] and 0x80'u8) == 0x80'u8:
+            # Negative integer
+            if length <= 8:
+              # We need this transformation because our field.vint is uint64.
+              for i in 0 ..< 8:
+                if i < 8 - int(length):
+                  field.vint = (field.vint shl 8) or 0xFF'u64
+                else:
+                  let offset = ab.offset + i - (8 - int(length))
+                  field.vint = (field.vint shl 8) or uint64(ab.buffer[offset])
+          else:
+            # Positive integer
+            if length <= 8:
+              for i in 0 ..< int(length):
+                field.vint = (field.vint shl 8) or
+                             uint64(ab.buffer[ab.offset + i])
+          ab.offset += int(length)
+          return ok(field)
+        else:
+          if length == 1:
+            # Zero value integer
+            field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
+                              index: ttag, offset: int(ab.offset),
+                              length: int(length), vint: 0'u64)
+            shallowCopy(field.buffer, ab.buffer)
+            ab.offset += int(length)
+            return ok(field)
+          else:
+            # Positive integer with leading zero
+            field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass,
+                              index: ttag, offset: int(ab.offset) + 1,
+                              length: int(length) - 1)
+            shallowCopy(field.buffer, ab.buffer)
+            if length <= 9:
+              for i in 1 ..< int(length):
+                field.vint = (field.vint shl 8) or
+                             uint64(ab.buffer[ab.offset + i])
+            ab.offset += int(length)
+            return ok(field)
+      of Asn1Tag.BitString.code():
         # BIT STRING
-        if not ab.isEnough(cast[int](length)):
-          return err(Asn1Error.Incomplete)
-        field = Asn1Field(kind: Asn1Tag.BitString, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset + 1),
-                          length: cast[int](length - 1))
-        shallowCopy(field.buffer, ab.buffer)
-        field.ubits = cast[int](((length - 1) shl 3) - ab.buffer[ab.offset])
-        ab.offset += cast[int](length)
-        return ok(field)
-      elif cast[byte](tag) == Asn1Tag.OctetString.code():
-        # OCT STRING
-        if not ab.isEnough(cast[int](length)):
-          return err(Asn1Error.Incomplete)
-        field = Asn1Field(kind: Asn1Tag.OctetString, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
-                          length: cast[int](length))
-        shallowCopy(field.buffer, ab.buffer)
-        ab.offset += cast[int](length)
-        return ok(field)
-      elif cast[byte](tag) == Asn1Tag.Null.code():
+        if length == 0:
+          # BIT STRING should include `unused` bits field, so length should be
+          # bigger then 1.
+          return err(Asn1Error.Incorrect)
+        elif length == 1:
+          if ab.buffer[ab.offset] != 0x00'u8:
+            return err(Asn1Error.Incorrect)
+          else:
+            # Zero-length BIT STRING.
+            field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
+                              index: ttag, offset: int(ab.offset + 1),
+                              length: 0, ubits: 0)
+            shallowCopy(field.buffer, ab.buffer)
+            ab.offset += int(length)
+            return ok(field)
+        else:
+          if not ab.isEnough(int(length)):
+            return err(Asn1Error.Incomplete)
+          let unused = ab.buffer[ab.offset]
+          if unused > 0x07'u8:
+            # Number of unused bits should not be bigger then `7`.
+            return err(Asn1Error.Incorrect)
+          let mask = (1'u8 shl int(unused)) - 1'u8
+          if (ab.buffer[ab.offset + int(length) - 1] and mask) != 0x00'u8:
+            ## All unused bits should be set to `0`.
+            return err(Asn1Error.Incorrect)
+          field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass,
+                            index: ttag, offset: int(ab.offset + 1),
+                            length: int(length - 1), ubits: int(unused))
+          shallowCopy(field.buffer, ab.buffer)
+          ab.offset += int(length)
+          return ok(field)
+      of Asn1Tag.OctetString.code():
+        # OCTET STRING
+        if not ab.isEnough(int(length)):
+          return err(Asn1Error.Incomplete)
+        field = Asn1Field(kind: Asn1Tag.OctetString, klass: aclass,
+                          index: ttag, offset: int(ab.offset),
+                          length: int(length))
+        shallowCopy(field.buffer, ab.buffer)
+        ab.offset += int(length)
+        return ok(field)
+      of Asn1Tag.Null.code():
         # NULL
         if length != 0:
           return err(Asn1Error.Incorrect)
-        field = Asn1Field(kind: Asn1Tag.Null, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
-                          length: 0)
+        field = Asn1Field(kind: Asn1Tag.Null, klass: aclass, index: ttag,
+                          offset: int(ab.offset), length: 0)
         shallowCopy(field.buffer, ab.buffer)
-        ab.offset += cast[int](length)
+        ab.offset += int(length)
         return ok(field)
-      elif cast[byte](tag) == Asn1Tag.Oid.code():
+      of Asn1Tag.Oid.code():
         # OID
-        if not ab.isEnough(cast[int](length)):
+        if not ab.isEnough(int(length)):
           return err(Asn1Error.Incomplete)
-        field = Asn1Field(kind: Asn1Tag.Oid, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
-                          length: cast[int](length))
+        field = Asn1Field(kind: Asn1Tag.Oid, klass: aclass,
+                          index: ttag, offset: int(ab.offset),
+                          length: int(length))
         shallowCopy(field.buffer, ab.buffer)
-        ab.offset += cast[int](length)
+        ab.offset += int(length)
         return ok(field)
-      elif cast[byte](tag) == Asn1Tag.Sequence.code():
+      of Asn1Tag.Sequence.code():
         # SEQUENCE
-        if not ab.isEnough(cast[int](length)):
+        if not ab.isEnough(int(length)):
           return err(Asn1Error.Incomplete)
-        field = Asn1Field(kind: Asn1Tag.Sequence, klass: klass,
-                          index: ttag, offset: cast[int](ab.offset),
-                          length: cast[int](length))
+        field = Asn1Field(kind: Asn1Tag.Sequence, klass: aclass,
+                          index: ttag, offset: int(ab.offset),
+                          length: int(length))
         shallowCopy(field.buffer, ab.buffer)
-        ab.offset += cast[int](length)
+        ab.offset += int(length)
         return ok(field)
       else:
         return err(Asn1Error.NoSupport)
       inclass = false
       ttag = 0
     else:
       return err(Asn1Error.NoSupport)

-proc getBuffer*(field: Asn1Field): Asn1Buffer =
+proc getBuffer*(field: Asn1Field): Asn1Buffer {.inline.} =
   ## Return ``field`` as Asn1Buffer to enter composite types.
-  shallowCopy(result.buffer, field.buffer)
-  result.offset = field.offset
-  result.length = field.length
+  Asn1Buffer(buffer: field.buffer, offset: field.offset, length: field.length)

 proc `==`*(field: Asn1Field, data: openarray[byte]): bool =
   ## Compares field ``field`` data with ``data`` and returns ``true`` if both
   ## buffers are equal.
   let length = len(field.buffer)
-  if length > 0:
-    if field.length == len(data):
-      result = equalMem(unsafeAddr field.buffer[field.offset],
-                        unsafeAddr data[0], field.length)
+  if length == 0 and len(data) == 0:
+    true
+  else:
+    if length > 0:
+      if field.length == len(data):
+        CT.isEqual(
+          field.buffer.toOpenArray(field.offset,
+                                   field.offset + field.length - 1),
+          data.toOpenArray(0, field.length - 1))
+      else:
+        false
+    else:
+      false

 proc init*(t: typedesc[Asn1Buffer], data: openarray[byte]): Asn1Buffer =
   ## Initialize ``Asn1Buffer`` from array of bytes ``data``.
-  result.buffer = @data
+  Asn1Buffer(buffer: @data)

 proc init*(t: typedesc[Asn1Buffer], data: string): Asn1Buffer =
   ## Initialize ``Asn1Buffer`` from hexadecimal string ``data``.
-  result.buffer = fromHex(data)
+  Asn1Buffer(buffer: ncrutils.fromHex(data))

 proc init*(t: typedesc[Asn1Buffer]): Asn1Buffer =
   ## Initialize empty ``Asn1Buffer``.
-  result.buffer = newSeq[byte]()
+  Asn1Buffer(buffer: newSeq[byte]())

 proc init*(t: typedesc[Asn1Composite], tag: Asn1Tag): Asn1Composite =
   ## Initialize ``Asn1Composite`` with tag ``tag``.
-  result.tag = tag
-  result.buffer = newSeq[byte]()
+  Asn1Composite(tag: tag, buffer: newSeq[byte]())

 proc init*(t: typedesc[Asn1Composite], idx: int): Asn1Composite =
   ## Initialize ``Asn1Composite`` with tag context-specific id ``id``.
-  result.tag = Asn1Tag.Context
-  result.idx = idx
-  result.buffer = newSeq[byte]()
+  Asn1Composite(tag: Asn1Tag.Context, idx: idx, buffer: newSeq[byte]())

 proc `$`*(buffer: Asn1Buffer): string =
   ## Return string representation of ``buffer``.
-  result = toHex(buffer.toOpenArray())
+  ncrutils.toHex(buffer.toOpenArray())

 proc `$`*(field: Asn1Field): string =
   ## Return string representation of ``field``.
-  result = "["
-  result.add($field.kind)
-  result.add("]")
-  if field.kind == Asn1Tag.NoSupport:
-    result.add(" ")
-    result.add(toHex(field.toOpenArray()))
-  elif field.kind == Asn1Tag.Boolean:
-    result.add(" ")
-    result.add($field.vbool)
-  elif field.kind == Asn1Tag.Integer:
-    result.add(" ")
-    if field.length <= 8:
-      result.add($field.vint)
-    else:
-      result.add(toHex(field.toOpenArray()))
-  elif field.kind == Asn1Tag.BitString:
-    result.add(" ")
-    result.add("(")
-    result.add($field.ubits)
-    result.add(" bits) ")
-    result.add(toHex(field.toOpenArray()))
-  elif field.kind == Asn1Tag.OctetString:
-    result.add(" ")
-    result.add(toHex(field.toOpenArray()))
-  elif field.kind == Asn1Tag.Null:
-    result.add(" NULL")
-  elif field.kind == Asn1Tag.Oid:
-    result.add(" ")
-    result.add(toHex(field.toOpenArray()))
-  elif field.kind == Asn1Tag.Sequence:
-    result.add(" ")
-    result.add(toHex(field.toOpenArray()))
+  var res = "["
+  res.add($field.kind)
+  res.add("]")
+  case field.kind
+  of Asn1Tag.Boolean:
+    res.add(" ")
+    res.add($field.vbool)
+    res
+  of Asn1Tag.Integer:
+    res.add(" ")
+    if field.length <= 8:
+      res.add($field.vint)
+    else:
+      res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  of Asn1Tag.BitString:
+    res.add(" ")
+    res.add("(")
+    res.add($field.ubits)
+    res.add(" bits) ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  of Asn1Tag.OctetString:
+    res.add(" ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  of Asn1Tag.Null:
+    res.add(" NULL")
+    res
+  of Asn1Tag.Oid:
+    res.add(" ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  of Asn1Tag.Sequence:
+    res.add(" ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  of Asn1Tag.Context:
+    res.add(" ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res
+  else:
+    res.add(" ")
+    res.add(ncrutils.toHex(field.toOpenArray()))
+    res

 proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag) =
   ## Write empty value to buffer or composite with ``tag``.

@@ -655,7 +788,7 @@ proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag) =
   ## This procedure must be used to write `NULL`, `0` or empty `BIT STRING`,
   ## `OCTET STRING` types.
   doAssert(tag in {Asn1Tag.Null, Asn1Tag.Integer, Asn1Tag.BitString,
                    Asn1Tag.OctetString})
   var length: int
   if tag == Asn1Tag.Null:
     length = asn1EncodeNull(abc.toOpenArray())
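Putting the reworked reader together, a hedged sketch of walking a small DER blob (minasn1 import assumed; the hex is a toy SEQUENCE holding INTEGER 42 and NULL):

var ab = Asn1Buffer.init("300502012a0500")
let outer = ab.read()
if outer.isOk() and outer.get().kind == Asn1Tag.Sequence:
  var inner = outer.get().getBuffer()   # enter the SEQUENCE
  let number = inner.read()
  if number.isOk():
    echo "INTEGER value: ", number.get().vint   # 42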

View File

@@ -14,13 +14,13 @@
 ## Copyright(C) 2018 Thomas Pornin <pornin@bolet.org>.

 {.push raises: Defect.}

-import nimcrypto/utils
 import bearssl
 import minasn1
-export Asn1Error
 import stew/[results, ctops]
-export results
+# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures.
+import nimcrypto/utils as ncrutils
+export Asn1Error, results

 const
   DefaultPublicExponent* = 65537'u32

@@ -574,14 +574,16 @@ proc init*(sig: var RsaSignature, data: openarray[byte]): Result[void, Asn1Error
   else:
     err(Asn1Error.Incorrect)

-proc init*[T: RsaPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} =
+proc init*[T: RsaPKI](sospk: var T,
+                      data: string): Result[void, Asn1Error] {.inline.} =
   ## Initialize EC `private key`, `public key` or `scalar` ``sospk`` from
   ## hexadecimal string representation ``data``.
   ##
   ## Procedure returns ``Result[void, Asn1Status]``.
-  sospk.init(fromHex(data))
+  sospk.init(ncrutils.fromHex(data))

-proc init*(t: typedesc[RsaPrivateKey], data: openarray[byte]): RsaResult[RsaPrivateKey] =
+proc init*(t: typedesc[RsaPrivateKey],
+           data: openarray[byte]): RsaResult[RsaPrivateKey] =
   ## Initialize RSA private key from ASN.1 DER binary representation ``data``
   ## and return constructed object.
   var res: RsaPrivateKey

@@ -590,7 +592,8 @@ proc init*(t: typedesc[RsaPrivateKey], data: openarray[byte]): RsaResult[RsaPriv
   else:
     ok(res)

-proc init*(t: typedesc[RsaPublicKey], data: openarray[byte]): RsaResult[RsaPublicKey] =
+proc init*(t: typedesc[RsaPublicKey],
+           data: openarray[byte]): RsaResult[RsaPublicKey] =
   ## Initialize RSA public key from ASN.1 DER binary representation ``data``
   ## and return constructed object.
   var res: RsaPublicKey

@@ -599,7 +602,8 @@ proc init*(t: typedesc[RsaPublicKey], data: openarray[byte]): RsaResult[RsaPubli
   else:
     ok(res)

-proc init*(t: typedesc[RsaSignature], data: openarray[byte]): RsaResult[RsaSignature] =
+proc init*(t: typedesc[RsaSignature],
+           data: openarray[byte]): RsaResult[RsaSignature] =
   ## Initialize RSA signature from raw binary representation ``data`` and
   ## return constructed object.
   var res: RsaSignature

@@ -611,7 +615,7 @@ proc init*(t: typedesc[RsaSignature], data: openarray[byte]): RsaResult[RsaSigna
 proc init*[T: RsaPKI](t: typedesc[T], data: string): T {.inline.} =
   ## Initialize RSA `private key`, `public key` or `signature` from hexadecimal
   ## string representation ``data`` and return constructed object.
-  result = t.init(fromHex(data))
+  result = t.init(ncrutils.fromHex(data))

 proc `$`*(key: RsaPrivateKey): string =
   ## Return string representation of RSA private key.

@@ -622,21 +626,24 @@ proc `$`*(key: RsaPrivateKey): string =
     result.add($key.seck.nBitlen)
     result.add(" bits)\n")
     result.add("p = ")
-    result.add(toHex(getArray(key.buffer, key.seck.p, key.seck.plen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.seck.p, key.seck.plen)))
     result.add("\nq = ")
-    result.add(toHex(getArray(key.buffer, key.seck.q, key.seck.qlen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.seck.q, key.seck.qlen)))
     result.add("\ndp = ")
-    result.add(toHex(getArray(key.buffer, key.seck.dp, key.seck.dplen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dp,
+                                       key.seck.dplen)))
     result.add("\ndq = ")
-    result.add(toHex(getArray(key.buffer, key.seck.dq, key.seck.dqlen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dq,
+                                       key.seck.dqlen)))
     result.add("\niq = ")
-    result.add(toHex(getArray(key.buffer, key.seck.iq, key.seck.iqlen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.seck.iq,
+                                       key.seck.iqlen)))
     result.add("\npre = ")
-    result.add(toHex(getArray(key.buffer, key.pexp, key.pexplen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.pexp, key.pexplen)))
     result.add("\nm = ")
-    result.add(toHex(getArray(key.buffer, key.pubk.n, key.pubk.nlen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.pubk.n, key.pubk.nlen)))
     result.add("\npue = ")
-    result.add(toHex(getArray(key.buffer, key.pubk.e, key.pubk.elen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.pubk.e, key.pubk.elen)))
     result.add("\n")

 proc `$`*(key: RsaPublicKey): string =

@@ -648,9 +655,9 @@ proc `$`*(key: RsaPublicKey): string =
     result = "RSA key ("
     result.add($nbitlen)
     result.add(" bits)\nn = ")
-    result.add(toHex(getArray(key.buffer, key.key.n, key.key.nlen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.key.n, key.key.nlen)))
     result.add("\ne = ")
-    result.add(toHex(getArray(key.buffer, key.key.e, key.key.elen)))
+    result.add(ncrutils.toHex(getArray(key.buffer, key.key.e, key.key.elen)))
     result.add("\n")

 proc `$`*(sig: RsaSignature): string =

@@ -659,7 +666,7 @@ proc `$`*(sig: RsaSignature): string =
     result = "Empty or uninitialized RSA signature"
   else:
     result = "RSA signature ("
-    result.add(toHex(sig.buffer))
+    result.add(ncrutils.toHex(sig.buffer))
     result.add(")")

 proc `==`*(a, b: RsaPrivateKey): bool =


@ -138,12 +138,9 @@ proc closeRemote*(s: LPChannel) {.async.} =
trace "got EOF, closing channel" trace "got EOF, closing channel"
try: try:
await s.drainBuffer() await s.drainBuffer()
s.isEof = true # set EOF immediately to prevent further reads s.isEof = true # set EOF immediately to prevent further reads
await s.close() # close local end # close parent bufferstream to prevent further reads
await procCall BufferStream(s).close()
# call to avoid leaks
await procCall BufferStream(s).close() # close parent bufferstream
trace "channel closed on EOF" trace "channel closed on EOF"
except CancelledError as exc: except CancelledError as exc:


@ -96,7 +96,7 @@ proc newStreamInternal*(m: Mplex,
proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} = proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
## remove the local channel from the internal tables ## remove the local channel from the internal tables
## ##
await chann.closeEvent.wait() await chann.join()
if not isNil(chann): if not isNil(chann):
m.getChannelList(chann.initiator).del(chann.id) m.getChannelList(chann.initiator).del(chann.id)
trace "cleaned up channel", id = chann.id trace "cleaned up channel", id = chann.id


@ -31,14 +31,9 @@ type
method subscribeTopic*(f: FloodSub, method subscribeTopic*(f: FloodSub,
topic: string, topic: string,
subscribe: bool, subscribe: bool,
peerId: string) {.gcsafe, async.} = peerId: PeerID) {.gcsafe, async.} =
await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId) await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId)
let peer = f.peers.getOrDefault(peerId) let peer = f.peers.getOrDefault(peerId)
if peer == nil:
debug "subscribeTopic on a nil peer!", peer = peerId
return
if topic notin f.floodsub: if topic notin f.floodsub:
f.floodsub[topic] = initHashSet[PubSubPeer]() f.floodsub[topic] = initHashSet[PubSubPeer]()
@ -51,16 +46,20 @@ method subscribeTopic*(f: FloodSub,
# unsubscribe the peer from the topic # unsubscribe the peer from the topic
f.floodsub[topic].excl(peer) f.floodsub[topic].excl(peer)
method handleDisconnect*(f: FloodSub, peer: PubSubPeer) = method unsubscribePeer*(f: FloodSub, peer: PeerID) =
## handle peer disconnects ## handle peer disconnects
## ##
procCall PubSub(f).handleDisconnect(peer)
if not(isNil(peer)) and peer.peerInfo notin f.conns: trace "unsubscribing floodsub peer", peer = $peer
for t in toSeq(f.floodsub.keys): let pubSubPeer = f.peers.getOrDefault(peer)
if t in f.floodsub: if pubSubPeer.isNil:
f.floodsub[t].excl(peer) return
for t in toSeq(f.floodsub.keys):
if t in f.floodsub:
f.floodsub[t].excl(pubSubPeer)
procCall PubSub(f).unsubscribePeer(peer)
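Hedged call-site sketch: callers now pass a PeerID rather than a PubSubPeer, for example when the switch reports that a peer has gone away (fsub and peerId are placeholders):

  fsub.unsubscribePeer(peerId)   # a no-op if the peer was never tracked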
method rpcHandler*(f: FloodSub, method rpcHandler*(f: FloodSub,
peer: PubSubPeer, peer: PubSubPeer,
@ -77,7 +76,7 @@ method rpcHandler*(f: FloodSub,
if msgId notin f.seen: if msgId notin f.seen:
f.seen.put(msgId) # add the message to the seen cache f.seen.put(msgId) # add the message to the seen cache
if f.verifySignature and not msg.verify(peer.peerInfo): if f.verifySignature and not msg.verify(peer.peerId):
trace "dropping message due to failed signature verification" trace "dropping message due to failed signature verification"
continue continue
@ -102,7 +101,10 @@ method rpcHandler*(f: FloodSub,
trace "exception in message handler", exc = exc.msg trace "exception in message handler", exc = exc.msg
# forward the message to all peers interested in it # forward the message to all peers interested in it
let published = await f.publishHelper(toSendPeers, m.messages, DefaultSendTimeout) let published = await f.broadcast(
toSeq(toSendPeers),
RPCMsg(messages: m.messages),
DefaultSendTimeout)
trace "forwared message to peers", peers = published trace "forwared message to peers", peers = published
@ -118,11 +120,6 @@ method init*(f: FloodSub) =
f.handler = handler f.handler = handler
f.codec = FloodSubCodec f.codec = FloodSubCodec
method subscribePeer*(p: FloodSub,
conn: Connection) =
procCall PubSub(p).subscribePeer(conn)
asyncCheck p.handleConn(conn, FloodSubCodec)
method publish*(f: FloodSub, method publish*(f: FloodSub,
topic: string, topic: string,
data: seq[byte], data: seq[byte],
@ -143,7 +140,10 @@ method publish*(f: FloodSub,
let msg = Message.init(f.peerInfo, data, topic, f.msgSeqno, f.sign) let msg = Message.init(f.peerInfo, data, topic, f.msgSeqno, f.sign)
# start the future but do not wait yet # start the future but do not wait yet
let published = await f.publishHelper(f.floodsub.getOrDefault(topic), @[msg], timeout) let published = await f.broadcast(
toSeq(f.floodsub.getOrDefault(topic)),
RPCMsg(messages: @[msg]),
timeout)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
libp2p_pubsub_messages_published.inc(labelValues = [topic]) libp2p_pubsub_messages_published.inc(labelValues = [topic])
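Hedged call-site sketch for the broadcast-based publish path above (floodsub is a placeholder handle; toBytes comes from stew/byteutils):

  let delivered = await floodsub.publish("foobar", "hello".toBytes(), 10.seconds)
  trace "published to peers", count = delivered   # publish returns how many sends succeeded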
@ -167,8 +167,6 @@ method unsubscribeAll*(f: FloodSub, topic: string) {.async.} =
method initPubSub*(f: FloodSub) = method initPubSub*(f: FloodSub) =
procCall PubSub(f).initPubSub() procCall PubSub(f).initPubSub()
f.peers = initTable[string, PubSubPeer]()
f.topics = initTable[string, Topic]()
f.floodsub = initTable[string, HashSet[PubSubPeer]]() f.floodsub = initTable[string, HashSet[PubSubPeer]]()
f.seen = newTimedCache[string](2.minutes) f.seen = newTimedCache[string](2.minutes)
f.init() f.init()


@ -404,10 +404,10 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} =
.set(g.mesh.peers(topic).int64, labelValues = [topic]) .set(g.mesh.peers(topic).int64, labelValues = [topic])
# Send changes to peers after table updates to avoid stale state # Send changes to peers after table updates to avoid stale state
for p in grafts: let graft = RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))
await p.sendGraft(@[topic]) let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))
for p in prunes: discard await g.broadcast(grafts, graft, DefaultSendTimeout)
await p.sendPrune(@[topic]) discard await g.broadcast(prunes, prune, DefaultSendTimeout)
trace "mesh balanced, got peers", peers = g.mesh.peers(topic) trace "mesh balanced, got peers", peers = g.mesh.peers(topic)
@ -426,7 +426,7 @@ proc dropFanoutPeers(g: GossipSub) =
libp2p_gossipsub_peers_per_topic_fanout libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.peers(topic).int64, labelValues = [topic]) .set(g.fanout.peers(topic).int64, labelValues = [topic])
proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = proc getGossipPeers(g: GossipSub): Table[PubSubPeer, ControlMessage] {.gcsafe.} =
## gossip iHave messages to peers ## gossip iHave messages to peers
## ##
@ -458,10 +458,10 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} =
if peer in gossipPeers: if peer in gossipPeers:
continue continue
if peer.id notin result: if peer notin result:
result[peer.id] = controlMsg result[peer] = controlMsg
result[peer.id].ihave.add(ihave) result[peer].ihave.add(ihave)
func `/`(a, b: Duration): float64 = func `/`(a, b: Duration): float64 =
let let
@ -582,8 +582,11 @@ proc heartbeat(g: GossipSub) {.async.} =
let peers = g.getGossipPeers() let peers = g.getGossipPeers()
var sent: seq[Future[void]] var sent: seq[Future[void]]
for peer, control in peers: for peer, control in peers:
g.peers.withValue(peer, pubsubPeer) do: g.peers.withValue(peer.peerId, pubsubPeer) do:
sent &= pubsubPeer[].send(RPCMsg(control: some(control))) sent &= g.send(
pubsubPeer[],
RPCMsg(control: some(control)),
DefaultSendTimeout)
checkFutures(await allFinished(sent)) checkFutures(await allFinished(sent))
g.mcache.shift() # shift the cache g.mcache.shift() # shift the cache
@ -599,35 +602,37 @@ proc heartbeat(g: GossipSub) {.async.} =
await sleepAsync(GossipSubHeartbeatInterval) await sleepAsync(GossipSubHeartbeatInterval)
method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = method unsubscribePeer*(g: GossipSub, peer: PeerID) =
## handle peer disconnects ## handle peer disconnects
## ##
procCall FloodSub(g).handleDisconnect(peer) trace "unsubscribing gossipsub peer", peer = $peer
let pubSubPeer = g.peers.getOrDefault(peer)
if pubSubPeer.isNil:
return
for t in toSeq(g.gossipsub.keys):
g.gossipsub.removePeer(t, pubSubPeer)
if not(isNil(peer)) and peer.peerInfo notin g.conns:
for t in toSeq(g.gossipsub.keys):
g.gossipsub.removePeer(t, peer)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
libp2p_gossipsub_peers_per_topic_gossipsub libp2p_gossipsub_peers_per_topic_gossipsub
.set(g.gossipsub.peers(t).int64, labelValues = [t]) .set(g.gossipsub.peers(t).int64, labelValues = [t])
for t in toSeq(g.mesh.keys): for t in toSeq(g.mesh.keys):
if peer in g.mesh[t]: if peer in g.mesh[t]:
g.pruned(peer, t) g.pruned(peer, t)
g.mesh.removePeer(t, peer) g.mesh.removePeer(t, pubSubPeer)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
libp2p_gossipsub_peers_per_topic_mesh libp2p_gossipsub_peers_per_topic_mesh
.set(g.mesh.peers(t).int64, labelValues = [t]) .set(g.mesh.peers(t).int64, labelValues = [t])
for t in toSeq(g.fanout.keys): for t in toSeq(g.fanout.keys):
g.fanout.removePeer(t, peer) g.fanout.removePeer(t, pubSubPeer)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
libp2p_gossipsub_peers_per_topic_fanout libp2p_gossipsub_peers_per_topic_fanout
.set(g.fanout.peers(t).int64, labelValues = [t]) .set(g.fanout.peers(t).int64, labelValues = [t])
# TODO # TODO
# if peer.peerInfo.maintain: # if peer.peerInfo.maintain:
@ -644,19 +649,16 @@ method handleDisconnect*(g: GossipSub, peer: PubSubPeer) =
for topic, info in g.peerStats[peer].topicInfos.mpairs: for topic, info in g.peerStats[peer].topicInfos.mpairs:
info.firstMessageDeliveries = 0 info.firstMessageDeliveries = 0
method subscribePeer*(p: GossipSub, procCall FloodSub(g).unsubscribePeer(peer)
conn: Connection) =
procCall PubSub(p).subscribePeer(conn)
asyncCheck p.handleConn(conn, GossipSubCodec)
method subscribeTopic*(g: GossipSub, method subscribeTopic*(g: GossipSub,
topic: string, topic: string,
subscribe: bool, subscribe: bool,
peerId: string) {.gcsafe, async.} = peerId: PeerID) {.gcsafe, async.} =
await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId) await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId)
logScope: logScope:
peer = peerId peer = $peerId
topic topic
let peer = g.peers.getOrDefault(peerId) let peer = g.peers.getOrDefault(peerId)
@ -817,8 +819,8 @@ method rpcHandler*(g: GossipSub,
g.seen.put(msgId) # add the message to the seen cache g.seen.put(msgId) # add the message to the seen cache
if g.verifySignature and not msg.verify(peer.peerInfo): if g.verifySignature and not msg.verify(peer.peerId):
trace "dropping message due to failed signature verification", peer trace "dropping message due to failed signature verification"
g.punishPeer(peer, msg) g.punishPeer(peer, msg)
continue continue
@ -872,7 +874,10 @@ method rpcHandler*(g: GossipSub,
trace "exception in message handler", exc = exc.msg trace "exception in message handler", exc = exc.msg
# forward the message to all peers interested in it # forward the message to all peers interested in it
let published = await g.publishHelper(toSendPeers, m.messages, DefaultSendTimeout) let published = await g.broadcast(
toSeq(toSendPeers),
RPCMsg(messages: m.messages),
DefaultSendTimeout)
trace "forwared message to peers", peers = published trace "forwared message to peers", peers = published
@ -889,8 +894,10 @@ method rpcHandler*(g: GossipSub,
respControl.ihave.len > 0: respControl.ihave.len > 0:
try: try:
info "sending control message", msg = respControl info "sending control message", msg = respControl
await peer.send( await g.send(
RPCMsg(control: some(respControl), messages: messages)) peer,
RPCMsg(control: some(respControl), messages: messages),
DefaultSendTimeout)
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
@ -917,12 +924,10 @@ method unsubscribe*(g: GossipSub,
if topic in g.mesh: if topic in g.mesh:
let peers = g.mesh.getOrDefault(topic) let peers = g.mesh.getOrDefault(topic)
g.mesh.del(topic) g.mesh.del(topic)
var pending = newSeq[Future[void]]()
for peer in peers: for peer in peers:
g.pruned(peer, topic) g.pruned(peer, topic)
pending.add(peer.sendPrune(@[topic])) let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))
checkFutures(await allFinished(pending)) discard await g.broadcast(toSeq(peers), prune, DefaultSendTimeout)
method unsubscribeAll*(g: GossipSub, topic: string) {.async.} = method unsubscribeAll*(g: GossipSub, topic: string) {.async.} =
await procCall PubSub(g).unsubscribeAll(topic) await procCall PubSub(g).unsubscribeAll(topic)
@ -930,12 +935,10 @@ method unsubscribeAll*(g: GossipSub, topic: string) {.async.} =
if topic in g.mesh: if topic in g.mesh:
let peers = g.mesh.getOrDefault(topic) let peers = g.mesh.getOrDefault(topic)
g.mesh.del(topic) g.mesh.del(topic)
var pending = newSeq[Future[void]]()
for peer in peers: for peer in peers:
g.pruned(peer, topic) g.pruned(peer, topic)
pending.add(peer.sendPrune(@[topic])) let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))
checkFutures(await allFinished(pending)) discard await g.broadcast(toSeq(peers), prune, DefaultSendTimeout)
method publish*(g: GossipSub, method publish*(g: GossipSub,
topic: string, topic: string,
@ -986,7 +989,7 @@ method publish*(g: GossipSub,
if msgId notin g.mcache: if msgId notin g.mcache:
g.mcache.put(msgId, msg) g.mcache.put(msgId, msg)
let published = await g.publishHelper(peers, @[msg], timeout) let published = await g.broadcast(toSeq(peers), RPCMsg(messages: @[msg]), timeout)
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
if published > 0: if published > 0:
libp2p_pubsub_messages_published.inc(labelValues = [topic]) libp2p_pubsub_messages_published.inc(labelValues = [topic])


@ -13,10 +13,10 @@ import pubsubpeer, ../../peerid
type type
PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map
proc hasPeerID*(t: PeerTable, topic, peerId: string): bool = proc hasPeerID*(t: PeerTable, topic: string, peerId: PeerID): bool =
let peers = toSeq(t.getOrDefault(topic)) let peers = toSeq(t.getOrDefault(topic))
peers.any do (peer: PubSubPeer) -> bool: peers.any do (peer: PubSubPeer) -> bool:
peer.id == peerId peer.peerId == peerId
func addPeer*(table: var PeerTable, topic: string, peer: PubSubPeer): bool = func addPeer*(table: var PeerTable, topic: string, peer: PubSubPeer): bool =
# returns true if the peer was added, # returns true if the peer was added,


@ -11,6 +11,7 @@ import std/[tables, sequtils, sets]
import chronos, chronicles, metrics import chronos, chronicles, metrics
import pubsubpeer, import pubsubpeer,
rpc/[message, messages], rpc/[message, messages],
../../switch,
../protocol, ../protocol,
../../stream/connection, ../../stream/connection,
../../peerid, ../../peerid,
@ -53,64 +54,77 @@ type
handler*: seq[TopicHandler] handler*: seq[TopicHandler]
PubSub* = ref object of LPProtocol PubSub* = ref object of LPProtocol
switch*: Switch # the switch used to dial/connect to peers
peerInfo*: PeerInfo # this peer's info peerInfo*: PeerInfo # this peer's info
topics*: Table[string, Topic] # local topics topics*: Table[string, Topic] # local topics
peers*: Table[string, PubSubPeer] # peerid to peer map peers*: Table[PeerID, PubSubPeer] # peerid to peer map
conns*: Table[PeerInfo, HashSet[Connection]] # peers connections
triggerSelf*: bool # trigger own local handler on publish triggerSelf*: bool # trigger own local handler on publish
verifySignature*: bool # enable signature verification verifySignature*: bool # enable signature verification
sign*: bool # enable message signing sign*: bool # enable message signing
cleanupLock: AsyncLock cleanupLock: AsyncLock
validators*: Table[string, HashSet[ValidatorHandler]] validators*: Table[string, HashSet[ValidatorHandler]]
observers: ref seq[PubSubObserver] # ref as in smart_ptr observers: ref seq[PubSubObserver] # ref as in smart_ptr
msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) msgIdProvider*: MsgIdProvider # Turn message into message id (not nil)
msgSeqno*: uint64 msgSeqno*: uint64
lifetimeFut*: Future[void] # pubsub liftime future
method handleConnect*(p: PubSub, peer: PubSubPeer) {.base.} = method unsubscribePeer*(p: PubSub, peerId: PeerID) {.base.} =
discard
method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} =
## handle peer disconnects ## handle peer disconnects
## ##
if not(isNil(peer)) and peer.peerInfo notin p.conns:
trace "deleting peer", peer = peer.id
peer.onConnect.fire() # Make sure all pending sends are unblocked
p.peers.del(peer.id)
trace "peer disconnected", peer = peer.id
# metrics trace "unsubscribing pubsub peer", peer = $peerId
libp2p_pubsub_peers.set(p.peers.len.int64) if peerId in p.peers:
p.peers.del(peerId)
proc onConnClose(p: PubSub, conn: Connection) {.async.} = libp2p_pubsub_peers.set(p.peers.len.int64)
proc send*(
p: PubSub,
peer: PubSubPeer,
msg: RPCMsg,
timeout: Duration) {.async.} =
## send to remote peer
##
trace "sending pubsub message to peer", peer = $peer, msg = msg
try: try:
let peer = conn.peerInfo await peer.send(msg, timeout)
await conn.closeEvent.wait()
if peer in p.conns:
p.conns[peer].excl(conn)
if p.conns[peer].len <= 0:
p.conns.del(peer)
if peer.id in p.peers:
p.handleDisconnect(p.peers[peer.id])
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
trace "exception in onConnClose handler", exc = exc.msg trace "exception sending pubsub message to peer", peer = $peer, msg = msg
p.unsubscribePeer(peer.peerId)
raise exc
proc broadcast*(
p: PubSub,
sendPeers: seq[PubSubPeer],
msg: RPCMsg,
timeout: Duration): Future[int] {.async.} =
## send messages and cleanup failed peers
##
trace "broadcasting messages to peers", peers = sendPeers.len, message = msg
let sent = await allFinished(
sendPeers.mapIt( p.send(it, msg, timeout) ))
return sent.filterIt( it.finished and it.error.isNil ).len
trace "messages broadcasted to peers", peers = sent.len
proc sendSubs*(p: PubSub, proc sendSubs*(p: PubSub,
peer: PubSubPeer, peer: PubSubPeer,
topics: seq[string], topics: seq[string],
subscribe: bool) {.async.} = subscribe: bool) {.async.} =
## send subscriptions to remote peer ## send subscriptions to remote peer
asyncCheck peer.sendSubOpts(topics, subscribe) await p.send(
peer,
RPCMsg(
subscriptions: topics.mapIt(SubOpts(subscribe: subscribe, topic: it))),
DefaultSendTimeout)
method subscribeTopic*(p: PubSub, method subscribeTopic*(p: PubSub,
topic: string, topic: string,
subscribe: bool, subscribe: bool,
peerId: string) {.base, async.} = peerId: PeerID) {.base, async.} =
# called when remote peer subscribes to a topic # called when remote peer subscribes to a topic
var peer = p.peers.getOrDefault(peerId) var peer = p.peers.getOrDefault(peerId)
if not isNil(peer): if not isNil(peer):
@ -130,27 +144,27 @@ method rpcHandler*(p: PubSub,
if m.subscriptions.len > 0: # if there are any subscriptions if m.subscriptions.len > 0: # if there are any subscriptions
for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic
trace "about to subscribe to topic", topicId = s.topic trace "about to subscribe to topic", topicId = s.topic
await p.subscribeTopic(s.topic, s.subscribe, peer.id) await p.subscribeTopic(s.topic, s.subscribe, peer.peerId)
proc getOrCreatePeer(p: PubSub, proc getOrCreatePeer*(
peerInfo: PeerInfo, p: PubSub,
proto: string): PubSubPeer = peer: PeerID,
if peerInfo.id in p.peers: proto: string): PubSubPeer =
return p.peers[peerInfo.id] if peer in p.peers:
return p.peers[peer]
# create new pubsub peer # create new pubsub peer
let peer = newPubSubPeer(peerInfo, proto) let pubSubPeer = newPubSubPeer(peer, p.switch, proto)
trace "created new pubsub peer", peerId = peer.id trace "created new pubsub peer", peerId = $peer
p.peers[peer.id] = peer p.peers[peer] = pubSubPeer
peer.observers = p.observers pubSubPeer.observers = p.observers
handleConnect(p, peer) handleConnect(p, peer)
# metrics # metrics
libp2p_pubsub_peers.set(p.peers.len.int64) libp2p_pubsub_peers.set(p.peers.len.int64)
return pubSubPeer
return peer
method handleConn*(p: PubSub, method handleConn*(p: PubSub,
conn: Connection, conn: Connection,
@ -171,19 +185,11 @@ method handleConn*(p: PubSub,
await conn.close() await conn.close()
return return
# track connection
p.conns.mgetOrPut(conn.peerInfo,
initHashSet[Connection]())
.incl(conn)
asyncCheck p.onConnClose(conn)
proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
# call pubsub rpc handler # call pubsub rpc handler
await p.rpcHandler(peer, msgs) await p.rpcHandler(peer, msgs)
let peer = p.getOrCreatePeer(conn.peerInfo, proto) let peer = p.getOrCreatePeer(conn.peerInfo.peerId, proto)
if p.topics.len > 0: if p.topics.len > 0:
await p.sendSubs(peer, toSeq(p.topics.keys), true) await p.sendSubs(peer, toSeq(p.topics.keys), true)
@ -198,32 +204,16 @@ method handleConn*(p: PubSub,
finally: finally:
await conn.close() await conn.close()
method subscribePeer*(p: PubSub, conn: Connection) {.base.} = method subscribePeer*(p: PubSub, peer: PeerID) {.base.} =
if not(isNil(conn)): ## subscribe to remote peer to receive/send pubsub
trace "subscribing to peer", peerId = conn.peerInfo.id ## messages
##
# track connection let pubsubPeer = p.getOrCreatePeer(peer, p.codec)
p.conns.mgetOrPut(conn.peerInfo, if p.topics.len > 0:
initHashSet[Connection]()) asyncCheck p.sendSubs(pubsubPeer, toSeq(p.topics.keys), true)
.incl(conn)
asyncCheck p.onConnClose(conn) pubsubPeer.subscribed = true
let peer = p.getOrCreatePeer(conn.peerInfo, p.codec)
if not peer.connected:
peer.conn = conn
method unsubscribePeer*(p: PubSub, peerInfo: PeerInfo) {.base, async.} =
if peerInfo.id in p.peers:
let peer = p.peers[peerInfo.id]
trace "unsubscribing from peer", peerId = $peerInfo
if not(isNil(peer)) and not(isNil(peer.conn)):
await peer.conn.close()
proc connected*(p: PubSub, peerId: PeerID): bool =
p.peers.withValue($peerId, peer):
return peer[] != nil and peer[].connected
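Hedged usage sketch for the PeerID-based subscribePeer above: the caller subscribes a peer explicitly once a connection exists (pubsub and conn are placeholders):

  pubsub.subscribePeer(conn.peerInfo.peerId)   # creates or reuses the PubSubPeer and pushes current subscriptions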
method unsubscribe*(p: PubSub, method unsubscribe*(p: PubSub,
topics: seq[TopicPair]) {.base, async.} = topics: seq[TopicPair]) {.base, async.} =
@ -278,40 +268,6 @@ method subscribe*(p: PubSub,
# metrics # metrics
libp2p_pubsub_topics.set(p.topics.len.int64) libp2p_pubsub_topics.set(p.topics.len.int64)
proc publishHelper*(p: PubSub,
sendPeers: HashSet[PubSubPeer],
msgs: seq[Message],
timeout: Duration): Future[int] {.async.} =
# send messages and cleanup failed peers
var sent: seq[tuple[id: string, fut: Future[void]]]
for sendPeer in sendPeers:
# avoid sending to self
if sendPeer.peerInfo == p.peerInfo:
continue
trace "sending messages to peer", peer = sendPeer.id, msgs
sent.add((id: sendPeer.id, fut: sendPeer.send(RPCMsg(messages: msgs), timeout)))
var published: seq[string]
var failed: seq[string]
let futs = await allFinished(sent.mapIt(it.fut))
for s in futs:
let f = sent.filterIt(it.fut == s)
if f.len > 0:
if s.failed:
trace "sending messages to peer failed", peer = f[0].id
failed.add(f[0].id)
else:
trace "sending messages to peer succeeded", peer = f[0].id
published.add(f[0].id)
for f in failed:
let peer = p.peers.getOrDefault(f)
if not(isNil(peer)) and not(isNil(peer.conn)):
await peer.conn.close()
return published.len
method publish*(p: PubSub, method publish*(p: PubSub,
topic: string, topic: string,
data: seq[byte], data: seq[byte],
@ -381,28 +337,35 @@ method validate*(p: PubSub, message: Message): Future[bool] {.async, base.} =
else: else:
libp2p_pubsub_validation_failure.inc() libp2p_pubsub_validation_failure.inc()
proc newPubSub*[PubParams: object | bool](P: typedesc[PubSub], proc init*[PubParams: object | bool](
peerInfo: PeerInfo, P: typedesc[PubSub],
triggerSelf: bool = false, switch: Switch,
verifySignature: bool = true, triggerSelf: bool = false,
sign: bool = true, verifySignature: bool = true,
msgIdProvider: MsgIdProvider = defaultMsgIdProvider, sign: bool = true,
params: PubParams = false): P = msgIdProvider: MsgIdProvider = defaultMsgIdProvider,
parameters: PubParams = false): P =
when PubParams is bool: when PubParams is bool:
result = P(peerInfo: peerInfo, result = P(switch: switch,
peerInfo: switch.peerInfo,
triggerSelf: triggerSelf, triggerSelf: triggerSelf,
verifySignature: verifySignature, verifySignature: verifySignature,
sign: sign, sign: sign,
peers: initTable[PeerID, PubSubPeer](),
topics: initTable[string, Topic](),
cleanupLock: newAsyncLock(), cleanupLock: newAsyncLock(),
msgIdProvider: msgIdProvider) msgIdProvider: msgIdProvider)
else: else:
result = P(peerInfo: peerInfo, result = P(switch: switch,
triggerSelf: triggerSelf, peerInfo: switch.peerInfo,
verifySignature: verifySignature, triggerSelf: triggerSelf,
sign: sign, verifySignature: verifySignature,
cleanupLock: newAsyncLock(), sign: sign,
msgIdProvider: msgIdProvider, peers: initTable[PeerID, PubSubPeer](),
parameters: params) topics: initTable[string, Topic](),
cleanupLock: newAsyncLock(),
msgIdProvider: msgIdProvider,
parameters: parameters)
result.initPubSub() result.initPubSub()
@ -412,6 +375,3 @@ proc removeObserver*(p: PubSub; observer: PubSubObserver) =
let idx = p.observers[].find(observer) let idx = p.observers[].find(observer)
if idx != -1: if idx != -1:
p.observers[].del(idx) p.observers[].del(idx)
proc connected*(p: PubSub, peerInfo: PeerInfo): bool {.deprecated: "Use PeerID version".} =
peerInfo != nil and connected(p, peerInfo.peerId)

View File

@ -11,6 +11,7 @@ import std/[hashes, options, sequtils, strutils, tables, hashes, sets]
import chronos, chronicles, nimcrypto/sha2, metrics import chronos, chronicles, nimcrypto/sha2, metrics
import rpc/[messages, message, protobuf], import rpc/[messages, message, protobuf],
timedcache, timedcache,
../../switch,
../../peerid, ../../peerid,
../../peerinfo, ../../peerinfo,
../../stream/connection, ../../stream/connection,
@ -28,7 +29,6 @@ when defined(libp2p_expensive_metrics):
declareCounter(libp2p_pubsub_skipped_sent_messages, "number of sent skipped messages", labels = ["id"]) declareCounter(libp2p_pubsub_skipped_sent_messages, "number of sent skipped messages", labels = ["id"])
const const
DefaultReadTimeout* = 1.minutes
DefaultSendTimeout* = 10.seconds DefaultSendTimeout* = 10.seconds
type type
@ -37,15 +37,17 @@ type
onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [Defect].} onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [Defect].}
PubSubPeer* = ref object of RootObj PubSubPeer* = ref object of RootObj
proto*: string # the protocol that this peer joined from switch*: Switch # switch instance to dial peers
codec*: string # the protocol that this peer joined from
sendConn: Connection sendConn: Connection
peerInfo*: PeerInfo peerId*: PeerID
handler*: RPCHandler handler*: RPCHandler
topics*: HashSet[string] topics*: HashSet[string]
sentRpcCache: TimedCache[string] # cache for already sent messages sentRpcCache: TimedCache[string] # cache for already sent messages
recvdRpcCache: TimedCache[string] # cache for already received messages recvdRpcCache: TimedCache[string] # cache for already received messages
onConnect*: AsyncEvent
observers*: ref seq[PubSubObserver] # ref as in smart_ptr observers*: ref seq[PubSubObserver] # ref as in smart_ptr
subscribed*: bool # are we subscribed to this peer
sendLock*: AsyncLock # send connection lock
score*: float64 score*: float64
@ -57,19 +59,13 @@ func hash*(p: PubSubPeer): Hash =
# int is either 32/64, so intptr basically, pubsubpeer is a ref # int is either 32/64, so intptr basically, pubsubpeer is a ref
cast[pointer](p).hash cast[pointer](p).hash
proc id*(p: PubSubPeer): string = p.peerInfo.id proc id*(p: PubSubPeer): string =
doAssert(not p.isNil, "nil pubsubpeer")
p.peerId.pretty
proc connected*(p: PubSubPeer): bool = proc connected*(p: PubSubPeer): bool =
not(isNil(p.sendConn)) not p.sendConn.isNil and not
(p.sendConn.closed or p.sendConn.atEof)
proc `conn=`*(p: PubSubPeer, conn: Connection) =
if not(isNil(conn)):
trace "attaching send connection for peer", peer = p.id
p.sendConn = conn
p.onConnect.fire()
proc conn*(p: PubSubPeer): Connection =
p.sendConn
proc recvObservers(p: PubSubPeer, msg: var RPCMsg) = proc recvObservers(p: PubSubPeer, msg: var RPCMsg) =
# trigger hooks # trigger hooks
@ -88,12 +84,13 @@ proc sendObservers(p: PubSubPeer, msg: var RPCMsg) =
proc handle*(p: PubSubPeer, conn: Connection) {.async.} = proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
logScope: logScope:
peer = p.id peer = p.id
debug "starting pubsub read loop for peer", closed = conn.closed debug "starting pubsub read loop for peer", closed = conn.closed
try: try:
try: try:
while not conn.atEof: while not conn.atEof:
trace "waiting for data", closed = conn.closed trace "waiting for data", closed = conn.closed
let data = await conn.readLp(64 * 1024).wait(DefaultReadTimeout) let data = await conn.readLp(64 * 1024)
let digest = $(sha256.digest(data)) let digest = $(sha256.digest(data))
trace "read data from peer", data = data.shortLog trace "read data from peer", data = data.shortLog
if digest in p.recvdRpcCache: if digest in p.recvdRpcCache:
@ -129,12 +126,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} =
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
trace "Exception occurred in PubSubPeer.handle", exc = exc.msg trace "Exception occurred in PubSubPeer.handle", exc = exc.msg
raise exc
proc send*( proc send*(
p: PubSubPeer, p: PubSubPeer,
msg: RPCMsg, msg: RPCMsg,
timeout: Duration = DefaultSendTimeout) {.async.} = timeout: Duration = DefaultSendTimeout) {.async.} =
doAssert(not isNil(p), "pubsubpeer nil!")
logScope: logScope:
peer = p.id peer = p.id
rpcMsg = shortLog(msg) rpcMsg = shortLog(msg)
@ -160,91 +159,55 @@ proc send*(
libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id]) libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id])
return return
proc sendToRemote() {.async.} =
logScope:
peer = p.id
rpcMsg = shortLog(msg)
trace "about to send message"
if not p.onConnect.isSet:
await p.onConnect.wait()
if p.connected: # this can happen if the remote disconnected
trace "sending encoded msgs to peer"
await p.sendConn.writeLp(encoded)
p.sentRpcCache.put(digest)
trace "sent pubsub message to remote"
when defined(libp2p_expensive_metrics):
for x in mm.messages:
for t in x.topicIDs:
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [p.id, t])
let sendFut = sendToRemote()
try: try:
await sendFut.wait(timeout) trace "about to send message"
if not p.connected:
try:
await p.sendLock.acquire()
trace "no send connection, dialing peer"
# get a send connection if there is none
p.sendConn = await p.switch.dial(
p.peerId, p.codec)
if not p.connected:
raise newException(CatchableError, "unable to get send pubsub stream")
# install a reader on the send connection
asyncCheck p.handle(p.sendConn)
finally:
if p.sendLock.locked:
p.sendLock.release()
trace "sending encoded msgs to peer"
await p.sendConn.writeLp(encoded).wait(timeout)
p.sentRpcCache.put(digest)
trace "sent pubsub message to remote"
when defined(libp2p_expensive_metrics):
for x in mm.messages:
for t in x.topicIDs:
# metrics
libp2p_pubsub_sent_messages.inc(labelValues = [p.id, t])
except CatchableError as exc: except CatchableError as exc:
trace "unable to send to remote", exc = exc.msg trace "unable to send to remote", exc = exc.msg
if not sendFut.finished:
sendFut.cancel()
if not(isNil(p.sendConn)): if not(isNil(p.sendConn)):
await p.sendConn.close() await p.sendConn.close()
p.sendConn = nil p.sendConn = nil
p.onConnect.clear()
raise exc raise exc
proc sendSubOpts*(p: PubSubPeer, topics: seq[string], subscribe: bool) {.async.} =
trace "sending subscriptions", peer = p.id, subscribe, topicIDs = topics
try:
await p.send(RPCMsg(
subscriptions: topics.mapIt(SubOpts(subscribe: subscribe, topic: it))),
# the long timeout is mostly for cases where
# the connection is flaky at the beggingin
timeout = 3.minutes)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception sending subscriptions", exc = exc.msg
proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} =
trace "sending graft to peer", peer = p.id, topicIDs = topics
try:
await p.send(RPCMsg(control: some(
ControlMessage(graft: topics.mapIt(ControlGraft(topicID: it))))),
timeout = 1.minutes)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception sending grafts", exc = exc.msg
proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} =
trace "sending prune to peer", peer = p.id, topicIDs = topics
try:
await p.send(RPCMsg(control: some(
ControlMessage(prune: topics.mapIt(ControlPrune(topicID: it))))),
timeout = 1.minutes)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception sending prunes", exc = exc.msg
proc `$`*(p: PubSubPeer): string = proc `$`*(p: PubSubPeer): string =
p.id p.id
proc newPubSubPeer*(peerInfo: PeerInfo, proc newPubSubPeer*(peerId: PeerID,
proto: string): PubSubPeer = switch: Switch,
codec: string): PubSubPeer =
new result new result
result.proto = proto result.switch = switch
result.peerInfo = peerInfo result.codec = codec
result.peerId = peerId
result.sentRpcCache = newTimedCache[string](2.minutes) result.sentRpcCache = newTimedCache[string](2.minutes)
result.recvdRpcCache = newTimedCache[string](2.minutes) result.recvdRpcCache = newTimedCache[string](2.minutes)
result.onConnect = newAsyncEvent()
result.topics = initHashSet[string]() result.topics = initHashSet[string]()
result.sendLock = newAsyncLock()


@ -10,7 +10,8 @@
{.push raises: [Defect].} {.push raises: [Defect].}
import chronicles, metrics, stew/[byteutils, endians2] import chronicles, metrics, stew/[byteutils, endians2]
import ./messages, ./protobuf, import ./messages,
./protobuf,
../../../peerid, ../../../peerid,
../../../peerinfo, ../../../peerinfo,
../../../crypto/crypto, ../../../crypto/crypto,
@ -32,7 +33,7 @@ func defaultMsgIdProvider*(m: Message): string =
proc sign*(msg: Message, p: PeerInfo): CryptoResult[seq[byte]] = proc sign*(msg: Message, p: PeerInfo): CryptoResult[seq[byte]] =
ok((? p.privateKey.sign(PubSubPrefix & encodeMessage(msg))).getBytes()) ok((? p.privateKey.sign(PubSubPrefix & encodeMessage(msg))).getBytes())
proc verify*(m: Message, p: PeerInfo): bool = proc verify*(m: Message, p: PeerID): bool =
if m.signature.len > 0 and m.key.len > 0: if m.signature.len > 0 and m.key.len > 0:
var msg = m var msg = m
msg.signature = @[] msg.signature = @[]
@ -51,17 +52,17 @@ proc verify*(m: Message, p: PeerInfo): bool =
proc init*( proc init*(
T: type Message, T: type Message,
p: PeerInfo, peer: PeerInfo,
data: seq[byte], data: seq[byte],
topic: string, topic: string,
seqno: uint64, seqno: uint64,
sign: bool = true): Message {.gcsafe, raises: [CatchableError, Defect].} = sign: bool = true): Message {.gcsafe, raises: [CatchableError, Defect].} =
result = Message( result = Message(
fromPeer: p.peerId, fromPeer: peer.peerId,
data: data, data: data,
seqno: @(seqno.toBytesBE), # unefficient, fine for now seqno: @(seqno.toBytesBE), # unefficient, fine for now
topicIDs: @[topic]) topicIDs: @[topic])
if sign and p.publicKey.isSome: if sign and peer.publicKey.isSome:
result.signature = sign(result, p).tryGet() result.signature = sign(result, peer).tryGet()
result.key = p.publicKey.get().getBytes().tryGet() result.key = peer.publicKey.get().getBytes().tryGet()
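Hedged sketch of the updated message API: messages are still built from a PeerInfo, but verification now needs only the sender's PeerID (peerInfo, payload and the seqno value are placeholders):

  let msg = Message.init(peerInfo, payload, "foobar", seqno = 1'u64, sign = true)
  doAssert msg.verify(peerInfo.peerId)   # true when the embedded signature checks out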


@ -30,11 +30,13 @@ type
proc init*[T: SecureConn](C: type T, proc init*[T: SecureConn](C: type T,
conn: Connection, conn: Connection,
peerInfo: PeerInfo, peerInfo: PeerInfo,
observedAddr: Multiaddress): T = observedAddr: Multiaddress,
timeout: Duration = DefaultConnectionTimeout): T =
result = C(stream: conn, result = C(stream: conn,
peerInfo: peerInfo, peerInfo: peerInfo,
observedAddr: observedAddr, observedAddr: observedAddr,
closeEvent: conn.closeEvent) closeEvent: conn.closeEvent,
timeout: timeout)
result.initStream() result.initStream()
method initStream*(s: SecureConn) = method initStream*(s: SecureConn) =
@ -62,7 +64,7 @@ proc handleConn*(s: Secure,
initiator: bool): Future[Connection] {.async, gcsafe.} = initiator: bool): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator) var sconn = await s.handshake(conn, initiator)
if not isNil(sconn): if not isNil(sconn):
conn.closeEvent.wait() conn.join()
.addCallback do(udata: pointer = nil): .addCallback do(udata: pointer = nil):
asyncCheck sconn.close() asyncCheck sconn.close()


@ -1,16 +1,9 @@
# compile time options here
const
libp2p_pubsub_sign {.booldefine.} = true
libp2p_pubsub_verify {.booldefine.} = true
import import
options, tables, chronos, bearssl, options, tables, chronos, bearssl,
switch, peerid, peerinfo, stream/connection, multiaddress, switch, peerid, peerinfo, stream/connection, multiaddress,
crypto/crypto, transports/[transport, tcptransport], crypto/crypto, transports/[transport, tcptransport],
muxers/[muxer, mplex/mplex, mplex/types], muxers/[muxer, mplex/mplex, mplex/types],
protocols/[identify, secure/secure], protocols/[identify, secure/secure]
protocols/pubsub/[pubsub, floodsub, gossipsub],
protocols/pubsub/rpc/message
import import
protocols/secure/noise, protocols/secure/noise,
@ -26,17 +19,12 @@ type
proc newStandardSwitch*(privKey = none(PrivateKey), proc newStandardSwitch*(privKey = none(PrivateKey),
address = MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet(), address = MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet(),
triggerSelf = false,
gossip = false,
secureManagers: openarray[SecureProtocol] = [ secureManagers: openarray[SecureProtocol] = [
# array cos order matters # array cos order matters
SecureProtocol.Secio, SecureProtocol.Secio,
SecureProtocol.Noise, SecureProtocol.Noise,
], ],
verifySignature = libp2p_pubsub_verify,
sign = libp2p_pubsub_sign,
transportFlags: set[ServerFlags] = {}, transportFlags: set[ServerFlags] = {},
msgIdProvider: MsgIdProvider = defaultMsgIdProvider,
rng = newRng(), rng = newRng(),
inTimeout: Duration = 5.minutes, inTimeout: Duration = 5.minutes,
outTimeout: Duration = 5.minutes): Switch = outTimeout: Duration = 5.minutes): Switch =
@ -66,26 +54,11 @@ proc newStandardSwitch*(privKey = none(PrivateKey),
of SecureProtocol.Secio: of SecureProtocol.Secio:
secureManagerInstances &= newSecio(rng, seckey).Secure secureManagerInstances &= newSecio(rng, seckey).Secure
let pubSub = if gossip: let switch = newSwitch(
newPubSub(GossipSub,
peerInfo = peerInfo,
triggerSelf = triggerSelf,
verifySignature = verifySignature,
sign = sign,
msgIdProvider = msgIdProvider,
params = GossipSubParams.init()).PubSub
else:
newPubSub(FloodSub,
peerInfo = peerInfo,
triggerSelf = triggerSelf,
verifySignature = verifySignature,
sign = sign,
msgIdProvider = msgIdProvider).PubSub
newSwitch(
peerInfo, peerInfo,
transports, transports,
identify, identify,
muxers, muxers,
secureManagers = secureManagerInstances, secureManagers = secureManagerInstances)
pubSub = some(pubSub))
return switch
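Because newStandardSwitch no longer wires in a PubSub, callers assemble one themselves; a hedged sketch under that assumption (FloodSub.init follows the PubSub.init signature added in pubsub.nim, and every handle here is illustrative):

  let switch = newStandardSwitch()
  let flood = FloodSub.init(switch = switch, triggerSelf = true)
  switch.mount(flood)              # register the pubsub codec on the switch
  discard await switch.start()     # returns the transport listener futures
  await flood.start()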


@ -143,8 +143,10 @@ proc initBufferStream*(s: BufferStream,
trace "created bufferstream", oid = $s.oid trace "created bufferstream", oid = $s.oid
proc newBufferStream*(handler: WriteHandler = nil, proc newBufferStream*(handler: WriteHandler = nil,
size: int = DefaultBufferSize): BufferStream = size: int = DefaultBufferSize,
timeout: Duration = DefaultConnectionTimeout): BufferStream =
new result new result
result.timeout = timeout
result.initBufferStream(handler, size) result.initBufferStream(handler, size)
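Hedged sketch of the new parameter (the 30-second value is arbitrary): callers can now bound how long an idle buffered stream is kept open.

  let bs = newBufferStream(timeout = 30.seconds)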
proc popFirst*(s: BufferStream): byte = proc popFirst*(s: BufferStream): byte =


@ -45,7 +45,7 @@ template withExceptions(body: untyped) =
raise exc raise exc
except TransportIncompleteError: except TransportIncompleteError:
# for all intents and purposes this is an EOF # for all intents and purposes this is an EOF
raise newLPStreamEOFError() raise newLPStreamIncompleteError()
except TransportLimitError: except TransportLimitError:
raise newLPStreamLimitError() raise newLPStreamLimitError()
except TransportUseClosedError: except TransportUseClosedError:


@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import hashes import hashes, oids
import chronicles, chronos, metrics import chronicles, chronos, metrics
import lpstream, import lpstream,
../multiaddress, ../multiaddress,
@ -20,7 +20,7 @@ logScope:
const const
ConnectionTrackerName* = "libp2p.connection" ConnectionTrackerName* = "libp2p.connection"
DefaultConnectionTimeout* = 1.minutes DefaultConnectionTimeout* = 5.minutes
type type
TimeoutHandler* = proc(): Future[void] {.gcsafe.} TimeoutHandler* = proc(): Future[void] {.gcsafe.}
@ -73,8 +73,15 @@ method initStream*(s: Connection) =
procCall LPStream(s).initStream() procCall LPStream(s).initStream()
s.closeEvent = newAsyncEvent() s.closeEvent = newAsyncEvent()
if isNil(s.timeoutHandler):
s.timeoutHandler = proc() {.async.} =
await s.close()
trace "timeout", timeout = $s.timeout.millis
doAssert(isNil(s.timerTaskFut)) doAssert(isNil(s.timerTaskFut))
s.timerTaskFut = s.timeoutMonitor() # doAssert(s.timeout > 0.millis)
if s.timeout > 0.millis:
s.timerTaskFut = s.timeoutMonitor()
inc getConnectionTracker().opened inc getConnectionTracker().opened


@ -115,8 +115,12 @@ proc readExactly*(s: LPStream,
read += await s.readOnce(addr pbuffer[read], nbytes - read) read += await s.readOnce(addr pbuffer[read], nbytes - read)
if read < nbytes: if read < nbytes:
trace "incomplete data received", read if s.atEof:
raise newLPStreamIncompleteError() trace "couldn't read all bytes, stream EOF", expected = nbytes, read
raise newLPStreamEOFError()
else:
trace "couldn't read all bytes, incomplete data", expected = nbytes, read
raise newLPStreamIncompleteError()
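Hedged call-site sketch (stream and buf are placeholders): the split lets callers tell a stream that ended from one that merely returned short.

  var buf: array[4, byte]
  try:
    await stream.readExactly(addr buf[0], buf.len)
  except LPStreamEOFError:
    trace "stream ended before all bytes arrived"
  except LPStreamIncompleteError:
    trace "short read without reaching end of stream"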
proc readLine*(s: LPStream, proc readLine*(s: LPStream,
limit = 0, limit = 0,


@ -25,12 +25,14 @@ import stream/connection,
protocols/secure/secure, protocols/secure/secure,
peerinfo, peerinfo,
protocols/identify, protocols/identify,
protocols/pubsub/pubsub,
muxers/muxer, muxers/muxer,
connmanager, connmanager,
peerid, peerid,
errors errors
chronicles.formatIt(PeerInfo): $it
chronicles.formatIt(PeerID): $it
logScope: logScope:
topics = "switch" topics = "switch"
@ -44,9 +46,6 @@ declareCounter(libp2p_dialed_peers, "dialed peers")
declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrade, "peers failed upgrade") declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
const
MaxPubsubReconnectAttempts* = 10
type type
NoPubSubException* = object of CatchableError NoPubSubException* = object of CatchableError
@ -77,14 +76,8 @@ type
identity*: Identify identity*: Identify
streamHandler*: StreamHandler streamHandler*: StreamHandler
secureManagers*: seq[Secure] secureManagers*: seq[Secure]
pubSub*: Option[PubSub]
running: bool
dialLock: Table[PeerID, AsyncLock] dialLock: Table[PeerID, AsyncLock]
ConnEvents: Table[ConnEventKind, HashSet[ConnEventHandler]] ConnEvents: Table[ConnEventKind, HashSet[ConnEventHandler]]
pubsubMonitors: Table[PeerId, Future[void]]
proc newNoPubSubException(): ref NoPubSubException {.inline.} =
result = newException(NoPubSubException, "no pubsub provided!")
proc addConnEventHandler*(s: Switch, proc addConnEventHandler*(s: Switch,
handler: ConnEventHandler, kind: ConnEventKind) = handler: ConnEventHandler, kind: ConnEventKind) =
@ -111,23 +104,6 @@ proc triggerConnEvent(s: Switch, peerId: PeerID, event: ConnEvent) {.async, gcsa
warn "exception in trigger ConnEvents", exc = exc.msg warn "exception in trigger ConnEvents", exc = exc.msg
proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.} proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.}
proc subscribePeer*(s: Switch, peerId: PeerID) {.async, gcsafe.}
proc subscribePeerInternal(s: Switch, peerId: PeerID) {.async, gcsafe.}
proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} =
try:
await conn.closeEvent.wait()
trace "about to cleanup pubsub peer"
if s.pubSub.isSome:
let fut = s.pubsubMonitors.getOrDefault(conn.peerInfo.peerId)
if not(isNil(fut)) and not(fut.finished):
fut.cancel()
await s.pubSub.get().unsubscribePeer(conn.peerInfo)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception cleaning pubsub peer", exc = exc.msg
proc isConnected*(s: Switch, peerId: PeerID): bool = proc isConnected*(s: Switch, peerId: PeerID): bool =
## returns true if the peer has one or more ## returns true if the peer has one or more
@ -295,7 +271,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
proc internalConnect(s: Switch, proc internalConnect(s: Switch,
peerId: PeerID, peerId: PeerID,
addrs: seq[MultiAddress]): Future[Connection] {.async.} = addrs: seq[MultiAddress]): Future[Connection] {.async.} =
logScope: peer = peerId logScope:
peer = peerId
if s.peerInfo.peerId == peerId: if s.peerInfo.peerId == peerId:
raise newException(CatchableError, "can't dial self!") raise newException(CatchableError, "can't dial self!")
@ -353,12 +330,12 @@ proc internalConnect(s: Switch,
libp2p_failed_upgrade.inc() libp2p_failed_upgrade.inc()
raise exc raise exc
doAssert not isNil(upgraded), "checked in upgradeOutgoing" doAssert not isNil(upgraded), "connection died after upgradeOutgoing"
s.connManager.storeOutgoing(upgraded) s.connManager.storeOutgoing(upgraded)
conn = upgraded conn = upgraded
trace "dial successful", trace "dial successful",
oid = $conn.oid, oid = $upgraded.oid,
peerInfo = shortLog(upgraded.peerInfo) peerInfo = shortLog(upgraded.peerInfo)
break break
finally: finally:
@ -381,14 +358,31 @@ proc internalConnect(s: Switch,
# unworthy and disconnects it # unworthy and disconnects it
raise newException(CatchableError, "Connection closed during handshake") raise newException(CatchableError, "Connection closed during handshake")
asyncCheck s.cleanupPubSubPeer(conn)
asyncCheck s.subscribePeer(peerId)
return conn return conn
proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} = proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} =
discard await s.internalConnect(peerId, addrs) discard await s.internalConnect(peerId, addrs)
proc negotiateStream(s: Switch, stream: Connection, proto: string): Future[Connection] {.async.} =
trace "Attempting to select remote", proto = proto,
streamOid = $stream.oid,
oid = $stream.oid
if not await s.ms.select(stream, proto):
await stream.close()
raise newException(CatchableError, "Unable to select sub-protocol" & proto)
return stream
proc dial*(s: Switch,
peerId: PeerID,
proto: string): Future[Connection] {.async.} =
let stream = await s.connmanager.getMuxedStream(peerId)
if stream.isNil:
raise newException(CatchableError, "Couldn't get muxed stream")
return await s.negotiateStream(stream, proto)
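Hedged usage sketch: the new overload negotiates a protocol over an already-established muxed connection, so no addresses are needed (remotePeerId is a placeholder and FloodSubCodec stands in for any registered codec):

  let stream = await switch.dial(remotePeerId, FloodSubCodec)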
proc dial*(s: Switch, proc dial*(s: Switch,
peerId: PeerID, peerId: PeerID,
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
@ -409,14 +403,7 @@ proc dial*(s: Switch,
await conn.close() await conn.close()
raise newException(CatchableError, "Couldn't get muxed stream") raise newException(CatchableError, "Couldn't get muxed stream")
trace "Attempting to select remote", proto = proto, return await s.negotiateStream(stream, proto)
streamOid = $stream.oid,
oid = $conn.oid
if not await s.ms.select(stream, proto):
await stream.close()
raise newException(CatchableError, "Unable to select sub-protocol" & proto)
return stream
except CancelledError as exc: except CancelledError as exc:
trace "dial canceled" trace "dial canceled"
await cleanup() await cleanup()
@ -458,21 +445,12 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
s.peerInfo.addrs[i] = t.ma # update peer's address s.peerInfo.addrs[i] = t.ma # update peer's address
startFuts.add(server) startFuts.add(server)
if s.pubSub.isSome:
await s.pubSub.get().start()
debug "started libp2p node", peer = $s.peerInfo, addrs = s.peerInfo.addrs debug "started libp2p node", peer = $s.peerInfo, addrs = s.peerInfo.addrs
result = startFuts # listen for incoming connections result = startFuts # listen for incoming connections
proc stop*(s: Switch) {.async.} = proc stop*(s: Switch) {.async.} =
trace "stopping switch" trace "stopping switch"
# we want to report errors but we do not want to fail
# or crash here, cos we need to clean possibly MANY items
# and any following conn/transport won't be cleaned up
if s.pubSub.isSome:
await s.pubSub.get().stop()
# close and cleanup all connections # close and cleanup all connections
await s.connManager.close() await s.connManager.close()
@ -486,139 +464,6 @@ proc stop*(s: Switch) {.async.} =
trace "switch stopped" trace "switch stopped"
proc subscribePeerInternal(s: Switch, peerId: PeerID) {.async, gcsafe.} =
## Subscribe to pub sub peer
##
if s.pubSub.isSome and not s.pubSub.get().connected(peerId):
trace "about to subscribe to pubsub peer", peer = peerId
var stream: Connection
try:
stream = await s.connManager.getMuxedStream(peerId)
if isNil(stream):
trace "unable to subscribe to peer", peer = peerId
return
if not await s.ms.select(stream, s.pubSub.get().codec):
if not(isNil(stream)):
trace "couldn't select pubsub", codec = s.pubSub.get().codec
await stream.close()
return
s.pubSub.get().subscribePeer(stream)
await stream.closeEvent.wait()
except CancelledError as exc:
if not(isNil(stream)):
await stream.close()
raise exc
except CatchableError as exc:
trace "exception in subscribe to peer", peer = peerId,
exc = exc.msg
if not(isNil(stream)):
await stream.close()
proc pubsubMonitor(s: Switch, peerId: PeerID) {.async.} =
## while peer connected maintain a
## pubsub connection as well
##
while s.isConnected(peerId):
try:
trace "subscribing to pubsub peer", peer = peerId
await s.subscribePeerInternal(peerId)
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "exception in pubsub monitor", peer = peerId, exc = exc.msg
finally:
trace "sleeping before trying pubsub peer", peer = peerId
await sleepAsync(1.seconds) # allow the peer to cooldown
trace "exiting pubsub monitor", peer = peerId
proc subscribePeer*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} =
## Waits until ``server`` is not closed.
##
var retFuture = newFuture[void]("stream.transport.server.join")
let pubsubFut = s.pubsubMonitors.mgetOrPut(
peerId, s.pubsubMonitor(peerId))
proc continuation(udata: pointer) {.gcsafe.} =
retFuture.complete()
proc cancel(udata: pointer) {.gcsafe.} =
pubsubFut.removeCallback(continuation, cast[pointer](retFuture))
if not(pubsubFut.finished()):
pubsubFut.addCallback(continuation, cast[pointer](retFuture))
retFuture.cancelCallback = cancel
else:
retFuture.complete()
return retFuture
proc subscribe*(s: Switch, topic: string,
handler: TopicHandler) {.async.} =
## subscribe to a pubsub topic
##
if s.pubSub.isNone:
raise newNoPubSubException()
await s.pubSub.get().subscribe(topic, handler)
proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
## unsubscribe from topics
##
if s.pubSub.isNone:
raise newNoPubSubException()
await s.pubSub.get().unsubscribe(topics)
proc unsubscribeAll*(s: Switch, topic: string) {.async.} =
## unsubscribe from topics
if s.pubSub.isNone:
raise newNoPubSubException()
await s.pubSub.get().unsubscribeAll(topic)
proc publish*(s: Switch,
topic: string,
data: seq[byte],
timeout: Duration = InfiniteDuration): Future[int] {.async.} =
## pubslish to pubsub topic
##
if s.pubSub.isNone:
raise newNoPubSubException()
return await s.pubSub.get().publish(topic, data, timeout)
proc addValidator*(s: Switch,
topics: varargs[string],
hook: ValidatorHandler) =
## add validator
##
if s.pubSub.isNone:
raise newNoPubSubException()
s.pubSub.get().addValidator(topics, hook)
proc removeValidator*(s: Switch,
topics: varargs[string],
hook: ValidatorHandler) =
## pubslish to pubsub topic
##
if s.pubSub.isNone:
raise newNoPubSubException()
s.pubSub.get().removeValidator(topics, hook)
proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
var stream = await muxer.newStream() var stream = await muxer.newStream()
defer: defer:
@ -654,10 +499,6 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
asyncCheck s.triggerConnEvent( asyncCheck s.triggerConnEvent(
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: true)) peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: true))
# try establishing a pubsub connection
asyncCheck s.cleanupPubSubPeer(muxer.connection)
asyncCheck s.subscribePeer(peerId)
except CancelledError as exc: except CancelledError as exc:
await muxer.close() await muxer.close()
raise exc raise exc
@ -670,8 +511,7 @@ proc newSwitch*(peerInfo: PeerInfo,
transports: seq[Transport], transports: seq[Transport],
identity: Identify, identity: Identify,
muxers: Table[string, MuxerProvider], muxers: Table[string, MuxerProvider],
secureManagers: openarray[Secure] = [], secureManagers: openarray[Secure] = []): Switch =
pubSub: Option[PubSub] = none(PubSub)): Switch =
if secureManagers.len == 0: if secureManagers.len == 0:
raise (ref CatchableError)(msg: "Provide at least one secure manager") raise (ref CatchableError)(msg: "Provide at least one secure manager")
@ -704,24 +544,21 @@ proc newSwitch*(peerInfo: PeerInfo,
val.muxerHandler = proc(muxer: Muxer): Future[void] = val.muxerHandler = proc(muxer: Muxer): Future[void] =
s.muxerHandler(muxer) s.muxerHandler(muxer)
if pubSub.isSome: proc isConnected*(s: Switch, peerInfo: PeerInfo): bool
result.pubSub = pubSub {.deprecated: "Use PeerID version".} =
result.mount(pubSub.get())
proc isConnected*(s: Switch, peerInfo: PeerInfo): bool {.deprecated: "Use PeerID version".} =
not isNil(peerInfo) and isConnected(s, peerInfo.peerId) not isNil(peerInfo) and isConnected(s, peerInfo.peerId)
proc disconnect*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version", gcsafe.} = proc disconnect*(s: Switch, peerInfo: PeerInfo): Future[void]
{.deprecated: "Use PeerID version", gcsafe.} =
disconnect(s, peerInfo.peerId) disconnect(s, peerInfo.peerId)
proc connect*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version".} = proc connect*(s: Switch, peerInfo: PeerInfo): Future[void]
{.deprecated: "Use PeerID version".} =
connect(s, peerInfo.peerId, peerInfo.addrs) connect(s, peerInfo.peerId, peerInfo.addrs)
proc dial*(s: Switch, proc dial*(s: Switch,
peerInfo: PeerInfo, peerInfo: PeerInfo,
proto: string): proto: string):
Future[Connection] {.deprecated: "Use PeerID version".} = Future[Connection]
{.deprecated: "Use PeerID version".} =
dial(s, peerInfo.peerId, peerInfo.addrs, proto) dial(s, peerInfo.peerId, peerInfo.addrs, proto)
proc subscribePeer*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version", gcsafe.} =
subscribePeer(s, peerInfo.peerId)


@ -29,9 +29,9 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
# turn things deterministic # turn things deterministic
# this is for testing purposes only # this is for testing purposes only
var ceil = 15 var ceil = 15
let fsub = cast[FloodSub](sender.pubSub.get()) let fsub = cast[FloodSub](sender)
while not fsub.floodsub.hasKey(key) or while not fsub.floodsub.hasKey(key) or
not fsub.floodsub.hasPeerID(key, receiver.peerInfo.id): not fsub.floodsub.hasPeerID(key, receiver.peerInfo.peerId):
await sleepAsync(100.millis) await sleepAsync(100.millis)
dec ceil dec ceil
doAssert(ceil > 0, "waitSub timeout!") doAssert(ceil > 0, "waitSub timeout!")
@ -43,7 +43,7 @@ suite "FloodSub":
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "FloodSub basic publish/subscribe A -> B": test "FloodSub basic publish/subscribe A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var completionFut = newFuture[bool]() var completionFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
@ -51,19 +51,32 @@ suite "FloodSub":
let let
nodes = generateNodes(2) nodes = generateNodes(2)
# start switches
nodesFut = await allFinished( nodesFut = await allFinished(
nodes[0].start(), nodes[0].switch.start(),
nodes[1].start() nodes[1].switch.start(),
) )
let subscribes = await subscribeNodes(nodes) # start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar") await waitSub(nodes[0], nodes[1], "foobar")
check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0 check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0
check (await completionFut.wait(5.seconds)) == true
result = await completionFut.wait(5.seconds) await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
@ -71,53 +84,80 @@ suite "FloodSub":
) )
await allFuturesThrowing(nodesFut.concat()) await allFuturesThrowing(nodesFut.concat())
await allFuturesThrowing(subscribes)
check: waitFor(runTests())
waitFor(runTests()) == true
test "FloodSub basic publish/subscribe B -> A": test "FloodSub basic publish/subscribe B -> A":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var completionFut = newFuture[bool]() var completionFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
completionFut.complete(true) completionFut.complete(true)
var nodes = generateNodes(2) let
var awaiters: seq[Future[void]] nodes = generateNodes(2)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await waitSub(nodes[1], nodes[0], "foobar") await waitSub(nodes[1], nodes[0], "foobar")
check (await nodes[1].publish("foobar", "Hello!".toBytes())) > 0 check (await nodes[1].publish("foobar", "Hello!".toBytes())) > 0
result = await completionFut.wait(5.seconds) check (await completionFut.wait(5.seconds)) == true
await allFuturesThrowing(nodes[0].stop(), nodes[1].stop()) await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(
await allFuturesThrowing(awaiters) nodes[0].stop(),
nodes[1].stop()
)
check: await allFuturesThrowing(nodesFut)
waitFor(runTests()) == true
waitFor(runTests())
test "FloodSub validation should succeed": test "FloodSub validation should succeed":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var handlerFut = newFuture[bool]() var handlerFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
handlerFut.complete(true) handlerFut.complete(true)
var nodes = generateNodes(2) let
var awaiters: seq[Future[void]] nodes = generateNodes(2)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start())) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
let subscribes = await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar") await waitSub(nodes[0], nodes[1], "foobar")
@ -131,30 +171,44 @@ suite "FloodSub":
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0 check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0
check (await handlerFut) == true check (await handlerFut) == true
await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaiters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true
test "FloodSub validation should fail": test "FloodSub validation should fail":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check false # if we get here, it should fail check false # if we get here, it should fail
var nodes = generateNodes(2) let
var awaiters: seq[Future[void]] nodes = generateNodes(2)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar") await waitSub(nodes[0], nodes[1], "foobar")
@ -168,30 +222,44 @@ suite "FloodSub":
discard await nodes[0].publish("foobar", "Hello!".toBytes()) discard await nodes[0].publish("foobar", "Hello!".toBytes())
await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaiters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true
test "FloodSub validation one fails and one succeeds": test "FloodSub validation one fails and one succeeds":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var handlerFut = newFuture[bool]() var handlerFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foo" check topic == "foo"
handlerFut.complete(true) handlerFut.complete(true)
var nodes = generateNodes(2) let
var awaiters: seq[Future[void]] nodes = generateNodes(2)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[1].subscribe("foo", handler) await nodes[1].subscribe("foo", handler)
await waitSub(nodes[0], nodes[1], "foo") await waitSub(nodes[0], nodes[1], "foo")
await nodes[1].subscribe("bar", handler) await nodes[1].subscribe("bar", handler)
@ -210,57 +278,21 @@ suite "FloodSub":
check (await nodes[0].publish("bar", "Hello!".toBytes())) > 0 check (await nodes[0].publish("bar", "Hello!".toBytes())) > 0
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].switch.stop(),
nodes[1].stop()) nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes)
await allFuturesThrowing(awaiters)
result = true
check:
waitFor(runTests()) == true
test "FloodSub publish should fail on timeout":
proc runTests(): Future[bool] {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
discard
var nodes = generateNodes(2)
var awaiters: seq[Future[void]]
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar")
let pubsub = nodes[0].pubSub.get()
let peer = pubsub.peers[nodes[1].peerInfo.id]
peer.conn = Connection(newBufferStream(
proc (data: seq[byte]) {.async, gcsafe.} =
await sleepAsync(10.seconds)
,size = 0))
let in10millis = Moment.fromNow(10.millis)
let sent = await nodes[0].publish("foobar", "Hello!".toBytes(), 10.millis)
check Moment.now() >= in10millis
check sent == 0
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaiters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true
test "FloodSub multiple peers, no self trigger": test "FloodSub multiple peers, no self trigger":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var runs = 10 var runs = 10
var futs = newSeq[(Future[void], TopicHandler, ref int)](runs) var futs = newSeq[(Future[void], TopicHandler, ref int)](runs)
@ -279,15 +311,12 @@ suite "FloodSub":
counter counter
) )
var nodes: seq[Switch] = newSeq[Switch]() let
for i in 0..<runs: nodes = generateNodes(runs, triggerSelf = false)
nodes.add newStandardSwitch(secureManagers = [SecureProtocol.Noise]) nodesFut = nodes.mapIt(it.switch.start())
var awaitters: seq[Future[void]] await allFuturesThrowing(nodes.mapIt(it.start()))
for i in 0..<runs: await subscribeNodes(nodes)
awaitters.add(await nodes[i].start())
let subscribes = await subscribeNodes(nodes)
for i in 0..<runs: for i in 0..<runs:
await nodes[i].subscribe("foobar", futs[i][1]) await nodes[i].subscribe("foobar", futs[i][1])
@ -305,17 +334,18 @@ suite "FloodSub":
await allFuturesThrowing(pubs) await allFuturesThrowing(pubs)
await allFuturesThrowing(futs.mapIt(it[0])) await allFuturesThrowing(futs.mapIt(it[0]))
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes.mapIt(
allFutures(
it.stop(),
it.switch.stop())))
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaitters)
result = true waitFor(runTests())
check:
waitFor(runTests()) == true
test "FloodSub multiple peers, with self trigger": test "FloodSub multiple peers, with self trigger":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var runs = 10 var runs = 10
var futs = newSeq[(Future[void], TopicHandler, ref int)](runs) var futs = newSeq[(Future[void], TopicHandler, ref int)](runs)
@ -329,21 +359,17 @@ suite "FloodSub":
(proc(topic: string, data: seq[byte]) {.async, gcsafe.} = (proc(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
inc counter[] inc counter[]
if counter[] == runs: if counter[] == runs - 1:
fut.complete()), fut.complete()),
counter counter
) )
var nodes: seq[Switch] = newSeq[Switch]() let
for i in 0..<runs: nodes = generateNodes(runs, triggerSelf = true)
nodes.add newStandardSwitch(triggerSelf = true, secureManagers = [SecureProtocol.Secio]) nodesFut = nodes.mapIt(it.switch.start())
await allFuturesThrowing(nodes.mapIt(it.start()))
var awaitters: seq[Future[void]] await subscribeNodes(nodes)
for i in 0..<runs:
awaitters.add(await nodes[i].start())
let subscribes = await subscribeNodes(nodes)
for i in 0..<runs: for i in 0..<runs:
await nodes[i].subscribe("foobar", futs[i][1]) await nodes[i].subscribe("foobar", futs[i][1])
@ -361,12 +387,12 @@ suite "FloodSub":
await allFuturesThrowing(pubs) await allFuturesThrowing(pubs)
await allFuturesThrowing(futs.mapIt(it[0])) await allFuturesThrowing(futs.mapIt(it[0]))
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes.mapIt(
allFutures(
it.stop(),
it.switch.stop())))
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaitters)
result = true waitFor(runTests())
check:
waitFor(runTests()) == true

View File

@ -4,6 +4,7 @@ include ../../libp2p/protocols/pubsub/gossipsub
import unittest, bearssl import unittest, bearssl
import stew/byteutils import stew/byteutils
import ../../libp2p/standard_setup
import ../../libp2p/errors import ../../libp2p/errors
import ../../libp2p/crypto/crypto import ../../libp2p/crypto/crypto
import ../../libp2p/stream/bufferstream import ../../libp2p/stream/bufferstream
@ -38,7 +39,7 @@ suite "GossipSub internal":
test "`rebalanceMesh` Degree Lo": test "`rebalanceMesh` Degree Lo":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
let topic = "foobar" let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[PubSubPeer]() gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
@ -50,11 +51,8 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipSub.switch, GossipSubCodec)
peer.conn = conn gossipSub.peers[peerInfo.peerId] = peer
gossipSub.peers[peerInfo.id] = peer
gossipSub.handleConnect(peer)
gossipSub.grafted(peer, topic)
gossipSub.mesh[topic].incl(peer) gossipSub.mesh[topic].incl(peer)
check gossipSub.peers.len == 15 check gossipSub.peers.len == 15
@ -62,7 +60,7 @@ suite "GossipSub internal":
check gossipSub.mesh[topic].len == GossipSubD check gossipSub.mesh[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
check: check:
@ -70,7 +68,7 @@ suite "GossipSub internal":
test "`rebalanceMesh` Degree Hi": test "`rebalanceMesh` Degree Hi":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
let topic = "foobar" let topic = "foobar"
gossipSub.mesh[topic] = initHashSet[PubSubPeer]() gossipSub.mesh[topic] = initHashSet[PubSubPeer]()
@ -83,11 +81,8 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.conn = conn gossipSub.peers[peerInfo.peerId] = peer
gossipSub.peers[peerInfo.id] = peer
gossipSub.handleConnect(peer)
gossipSub.grafted(peer, topic)
gossipSub.mesh[topic].incl(peer) gossipSub.mesh[topic].incl(peer)
check gossipSub.mesh[topic].len == 15 check gossipSub.mesh[topic].len == 15
@ -95,6 +90,7 @@ suite "GossipSub internal":
check gossipSub.mesh[topic].len == GossipSubD check gossipSub.mesh[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -103,7 +99,7 @@ suite "GossipSub internal":
test "`replenishFanout` Degree Lo": test "`replenishFanout` Degree Lo":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -117,7 +113,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
var peerInfo = randomPeerInfo() var peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
gossipSub.gossipsub[topic].incl(peer) gossipSub.gossipsub[topic].incl(peer)
@ -126,6 +122,7 @@ suite "GossipSub internal":
check gossipSub.fanout[topic].len == GossipSubD check gossipSub.fanout[topic].len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -134,7 +131,7 @@ suite "GossipSub internal":
test "`dropFanoutPeers` drop expired fanout topics": test "`dropFanoutPeers` drop expired fanout topics":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -150,7 +147,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
gossipSub.fanout[topic].incl(peer) gossipSub.fanout[topic].incl(peer)
@ -160,6 +157,7 @@ suite "GossipSub internal":
check topic notin gossipSub.fanout check topic notin gossipSub.fanout
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -168,7 +166,7 @@ suite "GossipSub internal":
test "`dropFanoutPeers` leave unexpired fanout topics": test "`dropFanoutPeers` leave unexpired fanout topics":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -187,7 +185,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
gossipSub.fanout[topic1].incl(peer) gossipSub.fanout[topic1].incl(peer)
gossipSub.fanout[topic2].incl(peer) gossipSub.fanout[topic2].incl(peer)
@ -200,6 +198,7 @@ suite "GossipSub internal":
check topic2 in gossipSub.fanout check topic2 in gossipSub.fanout
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -208,7 +207,7 @@ suite "GossipSub internal":
test "`getGossipPeers` - should gather up to degree D non intersecting peers": test "`getGossipPeers` - should gather up to degree D non intersecting peers":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -225,7 +224,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
if i mod 2 == 0: if i mod 2 == 0:
gossipSub.fanout[topic].incl(peer) gossipSub.fanout[topic].incl(peer)
@ -238,7 +237,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
gossipSub.gossipsub[topic].incl(peer) gossipSub.gossipsub[topic].incl(peer)
@ -260,10 +259,11 @@ suite "GossipSub internal":
let peers = gossipSub.getGossipPeers() let peers = gossipSub.getGossipPeers()
check peers.len == GossipSubD check peers.len == GossipSubD
for p in peers.keys: for p in peers.keys:
check not gossipSub.fanout.hasPeerID(topic, p) check not gossipSub.fanout.hasPeerID(topic, p.peerId)
check not gossipSub.mesh.hasPeerID(topic, p) check not gossipSub.mesh.hasPeerID(topic, p.peerId)
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -272,7 +272,7 @@ suite "GossipSub internal":
test "`getGossipPeers` - should not crash on missing topics in mesh": test "`getGossipPeers` - should not crash on missing topics in mesh":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -286,7 +286,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipsub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
if i mod 2 == 0: if i mod 2 == 0:
gossipSub.fanout[topic].incl(peer) gossipSub.fanout[topic].incl(peer)
@ -308,6 +308,7 @@ suite "GossipSub internal":
check peers.len == GossipSubD check peers.len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -316,7 +317,7 @@ suite "GossipSub internal":
test "`getGossipPeers` - should not crash on missing topics in fanout": test "`getGossipPeers` - should not crash on missing topics in fanout":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -330,7 +331,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipSub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
if i mod 2 == 0: if i mod 2 == 0:
gossipSub.mesh[topic].incl(peer) gossipSub.mesh[topic].incl(peer)
@ -352,6 +353,7 @@ suite "GossipSub internal":
check peers.len == GossipSubD check peers.len == GossipSubD
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
@ -360,7 +362,7 @@ suite "GossipSub internal":
test "`getGossipPeers` - should not crash on missing topics in gossip": test "`getGossipPeers` - should not crash on missing topics in gossip":
proc testRun(): Future[bool] {.async.} = proc testRun(): Future[bool] {.async.} =
let gossipSub = newPubSub(TestGossipSub, randomPeerInfo(), params = params) let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} = proc handler(peer: PubSubPeer, msg: seq[RPCMsg]) {.async.} =
discard discard
@ -374,7 +376,7 @@ suite "GossipSub internal":
conns &= conn conns &= conn
let peerInfo = randomPeerInfo() let peerInfo = randomPeerInfo()
conn.peerInfo = peerInfo conn.peerInfo = peerInfo
let peer = newPubSubPeer(peerInfo, GossipSubCodec) let peer = newPubSubPeer(peerInfo.peerId, gossipSub.switch, GossipSubCodec)
peer.handler = handler peer.handler = handler
if i mod 2 == 0: if i mod 2 == 0:
gossipSub.mesh[topic].incl(peer) gossipSub.mesh[topic].incl(peer)
@ -396,6 +398,7 @@ suite "GossipSub internal":
check peers.len == 0 check peers.len == 0
await allFuturesThrowing(conns.mapIt(it.close())) await allFuturesThrowing(conns.mapIt(it.close()))
await gossipSub.switch.stop()
result = true result = true
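For reference, the construction pattern these internal tests now share, sketched with the assumption that TestGossipSub, params and randomPeerInfo are the helpers defined in this test file; the elided line stands in for whichever routine (rebalanceMesh, replenishFanout, getGossipPeers, ...) a given test exercises:

proc sketch() {.async.} =
  let gossipSub = TestGossipSub.init(newStandardSwitch(parameters = params))
  let topic = "foobar"
  gossipSub.mesh[topic] = initHashSet[PubSubPeer]()

  let peerInfo = randomPeerInfo()
  # PubSubPeer is now built from a PeerID plus the owning switch, and the
  # peer table is keyed by PeerID rather than the stringified id
  let peer = newPubSubPeer(peerInfo.peerId, gossipSub.switch, GossipSubCodec)
  gossipSub.peers[peerInfo.peerId] = peer
  gossipSub.mesh[topic].incl(peer)

  # ... exercise rebalanceMesh / replenishFanout / getGossipPeers here ...

  await gossipSub.switch.stop()   # the embedded switch now needs explicit teardown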

View File

@ -33,7 +33,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
# this is for testing purposes only # this is for testing purposes only
# peers can be inside `mesh` and `fanout`, not just `gossipsub` # peers can be inside `mesh` and `fanout`, not just `gossipsub`
var ceil = 15 var ceil = 15
let fsub = GossipSub(sender.pubSub.get()) let fsub = GossipSub(sender)
let ev = newAsyncEvent() let ev = newAsyncEvent()
fsub.heartbeatEvents.add(ev) fsub.heartbeatEvents.add(ev)
@ -42,11 +42,11 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} =
ev.clear() ev.clear()
while (not fsub.gossipsub.hasKey(key) or while (not fsub.gossipsub.hasKey(key) or
not fsub.gossipsub.hasPeerID(key, receiver.peerInfo.id)) and not fsub.gossipsub.hasPeerID(key, receiver.peerInfo.peerId)) and
(not fsub.mesh.hasKey(key) or (not fsub.mesh.hasKey(key) or
not fsub.mesh.hasPeerID(key, receiver.peerInfo.id)) and not fsub.mesh.hasPeerID(key, receiver.peerInfo.peerId)) and
(not fsub.fanout.hasKey(key) or (not fsub.fanout.hasKey(key) or
not fsub.fanout.hasPeerID(key , receiver.peerInfo.id)): not fsub.fanout.hasPeerID(key , receiver.peerInfo.peerId)):
trace "waitSub sleeping..." trace "waitSub sleeping..."
# await more heartbeats # await more heartbeats
@ -74,18 +74,29 @@ suite "GossipSub":
check tracker.isLeaked() == false check tracker.isLeaked() == false
test "GossipSub validation should succeed": test "GossipSub validation should succeed":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var handlerFut = newFuture[bool]() var handlerFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
handlerFut.complete(true) handlerFut.complete(true)
var nodes = generateNodes(2, true) let
var awaiters: seq[Future[void]] nodes = generateNodes(2, gossip = true)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
@ -107,35 +118,44 @@ suite "GossipSub":
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
result = (await validatorFut) and (await handlerFut) check (await validatorFut) and (await handlerFut)
let gossip1 = GossipSub(nodes[0].pubSub.get()) await allFuturesThrowing(
let gossip2 = GossipSub(nodes[1].pubSub.get()) nodes[0].switch.stop(),
check: nodes[1].switch.stop()
gossip1.mesh["foobar"].len == 1 and "foobar" notin gossip1.fanout )
gossip2.mesh["foobar"].len == 1 and "foobar" notin gossip2.fanout
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut.concat())
await allFuturesThrowing(awaiters)
check: waitFor(runTests())
waitFor(runTests()) == true
test "GossipSub validation should fail": test "GossipSub validation should fail":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check false # if we get here, it should fail check false # if we get here, it should fail
var nodes = generateNodes(2, true) let
var awaiters: seq[Future[void]] nodes = generateNodes(2, gossip = true)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
@ -163,7 +183,10 @@ suite "GossipSub":
nodes[1].addValidator("foobar", validator) nodes[1].addValidator("foobar", validator)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
result = await validatorFut check (await validatorFut) == true
let gossip1 = GossipSub(nodes[0])
let gossip2 = GossipSub(nodes[1])
# gossip 1.1, gossip1 peer with negative score will be pruned in gossip2, # gossip 1.1, gossip1 peer with negative score will be pruned in gossip2,
# and so mesh will be empty # and so mesh will be empty
@ -181,29 +204,45 @@ suite "GossipSub":
gossip1.mesh["foobar"].len == 1 and "foobar" notin gossip1.fanout gossip1.mesh["foobar"].len == 1 and "foobar" notin gossip1.fanout
"foobar" notin gossip2.mesh and "foobar" notin gossip2.fanout "foobar" notin gossip2.mesh and "foobar" notin gossip2.fanout
await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut.concat())
await allFuturesThrowing(awaiters)
check: waitFor(runTests())
waitFor(runTests()) == true
test "GossipSub validation one fails and one succeeds": test "GossipSub validation one fails and one succeeds":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var handlerFut = newFuture[bool]() var handlerFut = newFuture[bool]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foo" check topic == "foo"
handlerFut.complete(true) handlerFut.complete(true)
var nodes = generateNodes(2, true) let
var awaiters: seq[Future[void]] nodes = generateNodes(2, gossip = true)
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start())) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
let subscribes = await subscribeNodes(nodes)
await nodes[1].subscribe("foo", handler) await nodes[1].subscribe("foo", handler)
await nodes[1].subscribe("bar", handler) await nodes[1].subscribe("bar", handler)
@ -222,10 +261,11 @@ suite "GossipSub":
tryPublish await nodes[0].publish("foo", "Hello!".toBytes()), 1 tryPublish await nodes[0].publish("foo", "Hello!".toBytes()), 1
tryPublish await nodes[0].publish("bar", "Hello!".toBytes()), 1 tryPublish await nodes[0].publish("bar", "Hello!".toBytes()), 1
result = ((await passed) and (await failed) and (await handlerFut)) check ((await passed) and (await failed) and (await handlerFut))
let gossip1 = GossipSub(nodes[0])
let gossip2 = GossipSub(nodes[1])
let gossip1 = GossipSub(nodes[0].pubSub.get())
let gossip2 = GossipSub(nodes[1].pubSub.get())
check: check:
"foo" notin gossip1.mesh and gossip1.fanout["foo"].len == 1 "foo" notin gossip1.mesh and gossip1.fanout["foo"].len == 1
"foo" notin gossip2.mesh and "foo" notin gossip2.fanout "foo" notin gossip2.mesh and "foo" notin gossip2.fanout
@ -233,104 +273,95 @@ suite "GossipSub":
"bar" notin gossip2.mesh and "bar" notin gossip2.fanout "bar" notin gossip2.mesh and "bar" notin gossip2.fanout
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].switch.stop(),
nodes[1].stop()) nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes)
await allFuturesThrowing(awaiters)
result = true
check:
waitFor(runTests()) == true
test "GossipSub publish should fail on timeout":
proc runTests(): Future[bool] {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
discard
var nodes = generateNodes(2, gossip = true)
var awaiters: seq[Future[void]]
awaiters.add((await nodes[0].start()))
awaiters.add((await nodes[1].start()))
let subscribes = await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar")
let pubsub = nodes[0].pubSub.get()
let peer = pubsub.peers[nodes[1].peerInfo.id]
peer.conn = Connection(newBufferStream(
proc (data: seq[byte]) {.async, gcsafe.} =
await sleepAsync(10.seconds)
, size = 0))
let in10millis = Moment.fromNow(10.millis)
let sent = await nodes[0].publish("foobar", "Hello!".toBytes(), 10.millis)
check Moment.now() >= in10millis
check sent == 0
await allFuturesThrowing( await allFuturesThrowing(
nodes[0].stop(), nodes[0].stop(),
nodes[1].stop()) nodes[1].stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut.concat())
await allFuturesThrowing(awaiters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true
test "e2e - GossipSub should add remote peer topic subscriptions": test "e2e - GossipSub should add remote peer topic subscriptions":
proc testBasicGossipSub(): Future[bool] {.async.} = proc testBasicGossipSub() {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
discard discard
var nodes: seq[Switch] = newSeq[Switch]() let
for i in 0..<2: nodes = generateNodes(
nodes.add newStandardSwitch(gossip = true, 2,
secureManagers = [SecureProtocol.Noise]) gossip = true,
secureManagers = [SecureProtocol.Noise])
var awaitters: seq[Future[void]] # start switches
for node in nodes: nodesFut = await allFinished(
awaitters.add(await node.start()) nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
let subscribes = await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await sleepAsync(10.seconds) await sleepAsync(10.seconds)
let gossip1 = GossipSub(nodes[0].pubSub.get()) let gossip1 = GossipSub(nodes[0])
let gossip2 = GossipSub(nodes[1].pubSub.get()) let gossip2 = GossipSub(nodes[1])
check: check:
"foobar" in gossip2.topics "foobar" in gossip2.topics
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.peerId)
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(
await allFuturesThrowing(awaitters) nodes[0].stop(),
nodes[1].stop()
)
result = true await allFuturesThrowing(nodesFut.concat())
check: waitFor(testBasicGossipSub())
waitFor(testBasicGossipSub()) == true
test "e2e - GossipSub should add remote peer topic subscriptions if both peers are subscribed": test "e2e - GossipSub should add remote peer topic subscriptions if both peers are subscribed":
proc testBasicGossipSub(): Future[bool] {.async.} = proc testBasicGossipSub() {.async.} =
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
discard discard
var nodes: seq[Switch] = newSeq[Switch]() let
for i in 0..<2: nodes = generateNodes(
nodes.add newStandardSwitch(gossip = true, secureManagers = [SecureProtocol.Secio]) 2,
gossip = true,
secureManagers = [SecureProtocol.Secio])
var awaitters: seq[Future[void]] # start switches
for node in nodes: nodesFut = await allFinished(
awaitters.add(await node.start()) nodes[0].switch.start(),
nodes[1].switch.start(),
)
let subscribes = await subscribeNodes(nodes) # start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
@ -342,8 +373,8 @@ suite "GossipSub":
await allFuturesThrowing(subs) await allFuturesThrowing(subs)
let let
gossip1 = GossipSub(nodes[0].pubSub.get()) gossip1 = GossipSub(nodes[0])
gossip2 = GossipSub(nodes[1].pubSub.get()) gossip2 = GossipSub(nodes[1])
check: check:
"foobar" in gossip1.topics "foobar" in gossip1.topics
@ -352,35 +383,53 @@ suite "GossipSub":
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
"foobar" in gossip2.gossipsub "foobar" in gossip2.gossipsub
gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) or gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.peerId) or
gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId)
gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.id) or gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.peerId) or
gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id) gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.peerId)
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(
await allFuturesThrowing(awaitters) nodes[0].stop(),
nodes[1].stop()
)
result = true await allFuturesThrowing(nodesFut.concat())
check: waitFor(testBasicGossipSub())
waitFor(testBasicGossipSub()) == true
test "e2e - GossipSub send over fanout A -> B": test "e2e - GossipSub send over fanout A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var passed = newFuture[void]() var passed = newFuture[void]()
proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} =
check topic == "foobar" check topic == "foobar"
passed.complete() passed.complete()
var nodes = generateNodes(2, true) let
var wait = newSeq[Future[void]]() nodes = generateNodes(
wait.add(await nodes[0].start()) 2,
wait.add(await nodes[1].start()) gossip = true,
secureManagers = [SecureProtocol.Secio])
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
await waitSub(nodes[0], nodes[1], "foobar") await waitSub(nodes[0], nodes[1], "foobar")
@ -393,18 +442,19 @@ suite "GossipSub":
obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) = obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) =
inc observed inc observed
) )
nodes[1].pubsub.get().addObserver(obs1)
nodes[0].pubsub.get().addObserver(obs2) # nodes[1].addObserver(obs1)
# nodes[0].addObserver(obs2)
tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1
var gossip1: GossipSub = GossipSub(nodes[0].pubSub.get()) var gossip1: GossipSub = GossipSub(nodes[0])
var gossip2: GossipSub = GossipSub(nodes[1].pubSub.get()) var gossip2: GossipSub = GossipSub(nodes[1])
check: check:
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.id) gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.peerId)
not gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) not gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId)
await passed.wait(2.seconds) await passed.wait(2.seconds)
@ -413,14 +463,20 @@ suite "GossipSub":
await nodes[0].stop() await nodes[0].stop()
await nodes[1].stop() await nodes[1].stop()
await allFuturesThrowing(subscribes) await allFuturesThrowing(
await allFuturesThrowing(wait) nodes[0].switch.stop(),
nodes[1].switch.stop()
)
check observed == 2 await allFuturesThrowing(
result = true nodes[0].stop(),
nodes[1].stop()
)
check: await allFuturesThrowing(nodesFut.concat())
waitFor(runTests()) == true # check observed == 2
waitFor(runTests())
test "e2e - GossipSub send over mesh A -> B": test "e2e - GossipSub send over mesh A -> B":
proc runTests(): Future[bool] {.async.} = proc runTests(): Future[bool] {.async.} =
@ -429,16 +485,26 @@ suite "GossipSub":
check topic == "foobar" check topic == "foobar"
passed.complete(true) passed.complete(true)
var nodes = generateNodes(2, true) let
var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get()) nodes = generateNodes(
gossipSub1.parameters.floodPublish = false 2,
var gossipSub2: GossipSub = GossipSub(nodes[1].pubSub.get()) gossip = true,
gossipSub2.parameters.floodPublish = false secureManagers = [SecureProtocol.Secio])
var wait: seq[Future[void]]
wait.add(await nodes[0].start())
wait.add(await nodes[1].start())
let subscribes = await subscribeNodes(nodes) # start switches
nodesFut = await allFinished(
nodes[0].switch.start(),
nodes[1].switch.start(),
)
# start pubsub
await allFuturesThrowing(
allFinished(
nodes[0].start(),
nodes[1].start(),
))
await subscribeNodes(nodes)
await nodes[0].subscribe("foobar", handler) await nodes[0].subscribe("foobar", handler)
await nodes[1].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler)
@ -448,41 +514,42 @@ suite "GossipSub":
result = await passed result = await passed
var gossip1: GossipSub = GossipSub(nodes[0].pubSub.get()) var gossip1: GossipSub = GossipSub(nodes[0])
var gossip2: GossipSub = GossipSub(nodes[1].pubSub.get()) var gossip2: GossipSub = GossipSub(nodes[1])
check: check:
"foobar" in gossip1.gossipsub "foobar" in gossip1.gossipsub
"foobar" in gossip2.gossipsub "foobar" in gossip2.gossipsub
gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId)
not gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.id) not gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.peerId)
gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id) gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.peerId)
not gossip2.fanout.hasPeerID("foobar", gossip1.peerInfo.id) not gossip2.fanout.hasPeerID("foobar", gossip1.peerInfo.peerId)
await nodes[0].stop() await allFuturesThrowing(
await nodes[1].stop() nodes[0].switch.stop(),
nodes[1].switch.stop()
)
await allFuturesThrowing(subscribes) await allFuturesThrowing(
await allFuturesThrowing(wait) nodes[0].stop(),
nodes[1].stop()
)
await allFuturesThrowing(nodesFut.concat())
check: check:
waitFor(runTests()) == true waitFor(runTests()) == true
test "e2e - GossipSub with multiple peers": test "e2e - GossipSub with multiple peers":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var nodes: seq[Switch] = newSeq[Switch]()
var awaitters: seq[Future[void]]
var runs = 10 var runs = 10
for i in 0..<runs: let
nodes.add newStandardSwitch(triggerSelf = true, nodes = generateNodes(runs, gossip = true, triggerSelf = true)
gossip = true, nodesFut = nodes.mapIt(it.switch.start())
secureManagers = [SecureProtocol.Noise])
var gossipSub = GossipSub(nodes[i].pubSub.get())
gossipSub.parameters.floodPublish = false
awaitters.add((await nodes[i].start()))
let subscribes = await subscribeRandom(nodes) await allFuturesThrowing(nodes.mapIt(it.start()))
await subscribeNodes(nodes)
var seen: Table[string, int] var seen: Table[string, int]
var subs: seq[Future[void]] var subs: seq[Future[void]]
@ -514,36 +581,33 @@ suite "GossipSub":
check: v >= 1 check: v >= 1
for node in nodes: for node in nodes:
var gossip: GossipSub = GossipSub(node.pubSub.get()) var gossip = GossipSub(node)
check: check:
"foobar" in gossip.gossipsub "foobar" in gossip.gossipsub
gossip.fanout.len == 0 gossip.fanout.len == 0
gossip.mesh["foobar"].len > 0 gossip.mesh["foobar"].len > 0
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes.mapIt(
allFutures(
it.stop(),
it.switch.stop())))
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaitters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true
test "e2e - GossipSub with multiple peers (sparse)": test "e2e - GossipSub with multiple peers (sparse)":
proc runTests(): Future[bool] {.async.} = proc runTests() {.async.} =
var nodes: seq[Switch] = newSeq[Switch]()
var awaitters: seq[Future[void]]
var runs = 10 var runs = 10
for i in 0..<runs: let
nodes.add newStandardSwitch(triggerSelf = true, nodes = generateNodes(runs, gossip = true, triggerSelf = true)
gossip = true, nodesFut = nodes.mapIt(it.switch.start())
secureManagers = [SecureProtocol.Secio])
var gossipSub = GossipSub(nodes[i].pubSub.get())
gossipSub.parameters.floodPublish = false
awaitters.add((await nodes[i].start()))
let subscribes = await subscribeSparseNodes(nodes, 1) await allFuturesThrowing(nodes.mapIt(it.start()))
await subscribeNodes(nodes)
var seen: Table[string, int] var seen: Table[string, int]
var subs: seq[Future[void]] var subs: seq[Future[void]]
@ -576,17 +640,18 @@ suite "GossipSub":
check: v >= 1 check: v >= 1
for node in nodes: for node in nodes:
var gossip: GossipSub = GossipSub(node.pubSub.get()) var gossip = GossipSub(node)
check: check:
"foobar" in gossip.gossipsub "foobar" in gossip.gossipsub
gossip.fanout.len == 0 gossip.fanout.len == 0
gossip.mesh["foobar"].len > 0 gossip.mesh["foobar"].len > 0
await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(
nodes.mapIt(
allFutures(
it.stop(),
it.switch.stop())))
await allFuturesThrowing(subscribes) await allFuturesThrowing(nodesFut)
await allFuturesThrowing(awaitters)
result = true
check: waitFor(runTests())
waitFor(runTests()) == true

View File

@ -16,4 +16,4 @@ suite "Message":
peer = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) peer = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get())
msg = Message.init(peer, @[], "topic", seqno, sign = true) msg = Message.init(peer, @[], "topic", seqno, sign = true)
check verify(msg, peer) check verify(msg, peer.peerId)

View File

@ -1,27 +1,65 @@
import random, options # compile time options here
const
libp2p_pubsub_sign {.booldefine.} = true
libp2p_pubsub_verify {.booldefine.} = true
import random
import chronos import chronos
import ../../libp2p/standard_setup import ../../libp2p/[standard_setup,
import ../../libp2p/protocols/pubsub/gossipsub protocols/pubsub/pubsub,
protocols/pubsub/floodsub,
protocols/pubsub/gossipsub,
protocols/secure/secure]
export standard_setup export standard_setup
randomize() randomize()
proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] = proc generateNodes*(
for i in 0..<num: num: Natural,
var switch = newStandardSwitch(gossip = gossip) secureManagers: openarray[SecureProtocol] = [
if gossip: # array because order matters
var gossipSub = GossipSub(switch.pubSub.get()) SecureProtocol.Secio,
gossipSub.parameters.floodPublish = false SecureProtocol.Noise,
result.add(switch) ],
msgIdProvider: MsgIdProvider = nil,
gossip: bool = false,
triggerSelf: bool = false,
verifySignature: bool = libp2p_pubsub_verify,
sign: bool = libp2p_pubsub_sign): seq[PubSub] =
proc subscribeNodes*(nodes: seq[Switch]): Future[seq[Future[void]]] {.async.} = for i in 0..<num:
let switch = newStandardSwitch(secureManagers = secureManagers)
let pubsub = if gossip:
GossipSub.init(
switch = switch,
triggerSelf = triggerSelf,
verifySignature = verifySignature,
sign = sign,
msgIdProvider = msgIdProvider,
parameters = (
let p = GossipSubParams.init()
p.floodPublish = false
p)).PubSub
else:
FloodSub.init(
switch = switch,
triggerSelf = triggerSelf,
verifySignature = verifySignature,
sign = sign,
msgIdProvider = msgIdProvider).PubSub
switch.mount(pubsub)
result.add(pubsub)
proc subscribeNodes*(nodes: seq[PubSub]) {.async.} =
for dialer in nodes: for dialer in nodes:
for node in nodes: for node in nodes:
if dialer.peerInfo.peerId != node.peerInfo.peerId: if dialer.switch.peerInfo.peerId != node.switch.peerInfo.peerId:
await dialer.connect(node.peerInfo) await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs)
result.add(dialer.subscribePeer(node.peerInfo)) dialer.subscribePeer(node.peerInfo.peerId)
proc subscribeSparseNodes*(nodes: seq[Switch], degree: int = 2): Future[seq[Future[void]]] {.async.} = proc subscribeSparseNodes*(nodes: seq[PubSub], degree: int = 2) {.async.} =
if nodes.len < degree: if nodes.len < degree:
raise (ref CatchableError)(msg: "nodes count needs to be greater than or equal to degree!") raise (ref CatchableError)(msg: "nodes count needs to be greater than or equal to degree!")
@ -30,17 +68,17 @@ proc subscribeSparseNodes*(nodes: seq[Switch], degree: int = 2): Future[seq[Futu
continue continue
for node in nodes: for node in nodes:
if dialer.peerInfo.peerId != node.peerInfo.peerId: if dialer.switch.peerInfo.peerId != node.peerInfo.peerId:
await dialer.connect(node.peerInfo) await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs)
result.add(dialer.subscribePeer(node.peerInfo)) dialer.subscribePeer(node.peerInfo.peerId)
proc subscribeRandom*(nodes: seq[Switch]): Future[seq[Future[void]]] {.async.} = proc subscribeRandom*(nodes: seq[PubSub]) {.async.} =
for dialer in nodes: for dialer in nodes:
var dialed: seq[string] var dialed: seq[PeerID]
while dialed.len < nodes.len - 1: while dialed.len < nodes.len - 1:
let node = sample(nodes) let node = sample(nodes)
if node.peerInfo.id notin dialed: if node.peerInfo.peerId notin dialed:
if dialer.peerInfo.id != node.peerInfo.id: if dialer.peerInfo.peerId != node.peerInfo.peerId:
await dialer.connect(node.peerInfo) await dialer.switch.connect(node.peerInfo.peerId, node.peerInfo.addrs)
result.add(dialer.subscribePeer(node.peerInfo)) dialer.subscribePeer(node.peerInfo.peerId)
dialed.add(node.peerInfo.id) dialed.add(node.peerInfo.peerId)
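A side note on the libp2p_pubsub_sign / libp2p_pubsub_verify constants at the top of this helper: they use Nim's {.booldefine.} pragma, so the defaults can be overridden per build with -d:libp2p_pubsub_sign=false (or =true) instead of editing the file. A minimal illustration with a made-up flag name:

const exampleFlag {.booldefine.} = true   # hypothetical flag, named for illustration only

when exampleFlag:
  echo "exampleFlag is on (the default, or built with -d:exampleFlag)"
else:
  echo "exampleFlag is off (built with -d:exampleFlag=false)"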

View File

@ -72,11 +72,20 @@ proc testPubSubDaemonPublish(gossip: bool = false,
let daemonNode = await newDaemonApi(flags) let daemonNode = await newDaemonApi(flags)
let daemonPeer = await daemonNode.identity() let daemonPeer = await daemonNode.identity()
let nativeNode = newStandardSwitch( let nativeNode = newStandardSwitch(
gossip = gossip,
secureManagers = [SecureProtocol.Noise], secureManagers = [SecureProtocol.Noise],
outTimeout = 5.minutes) outTimeout = 5.minutes)
let pubsub = if gossip:
GossipSub.init(
switch = nativeNode).PubSub
else:
FloodSub.init(
switch = nativeNode).PubSub
nativeNode.mount(pubsub)
let awaiters = nativeNode.start() let awaiters = nativeNode.start()
await pubsub.start()
let nativePeer = nativeNode.peerInfo let nativePeer = nativeNode.peerInfo
var finished = false var finished = false
@ -91,8 +100,8 @@ proc testPubSubDaemonPublish(gossip: bool = false,
let peer = NativePeerInfo.init( let peer = NativePeerInfo.init(
daemonPeer.peer, daemonPeer.peer,
daemonPeer.addresses) daemonPeer.addresses)
await nativeNode.connect(peer) await nativeNode.connect(peer.peerId, peer.addrs)
let subscribeHandle = nativeNode.subscribePeer(peer) pubsub.subscribePeer(peer.peerId)
await sleepAsync(1.seconds) await sleepAsync(1.seconds)
await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs)
@ -103,7 +112,7 @@ proc testPubSubDaemonPublish(gossip: bool = false,
result = true # don't cancel subscription result = true # don't cancel subscription
asyncDiscard daemonNode.pubsubSubscribe(testTopic, pubsubHandler) asyncDiscard daemonNode.pubsubSubscribe(testTopic, pubsubHandler)
await nativeNode.subscribe(testTopic, nativeHandler) await pubsub.subscribe(testTopic, nativeHandler)
await sleepAsync(5.seconds) await sleepAsync(5.seconds)
proc publisher() {.async.} = proc publisher() {.async.} =
@ -115,9 +124,9 @@ proc testPubSubDaemonPublish(gossip: bool = false,
result = true result = true
await nativeNode.stop() await nativeNode.stop()
await pubsub.stop()
await allFutures(awaiters) await allFutures(awaiters)
await daemonNode.close() await daemonNode.close()
await subscribeHandle
proc testPubSubNodePublish(gossip: bool = false, proc testPubSubNodePublish(gossip: bool = false,
count: int = 1): Future[bool] {.async.} = count: int = 1): Future[bool] {.async.} =
@ -132,18 +141,27 @@ proc testPubSubNodePublish(gossip: bool = false,
let daemonNode = await newDaemonApi(flags) let daemonNode = await newDaemonApi(flags)
let daemonPeer = await daemonNode.identity() let daemonPeer = await daemonNode.identity()
let nativeNode = newStandardSwitch( let nativeNode = newStandardSwitch(
gossip = gossip,
secureManagers = [SecureProtocol.Secio], secureManagers = [SecureProtocol.Secio],
outTimeout = 5.minutes) outTimeout = 5.minutes)
let pubsub = if gossip:
GossipSub.init(
switch = nativeNode).PubSub
else:
FloodSub.init(
switch = nativeNode).PubSub
nativeNode.mount(pubsub)
let awaiters = nativeNode.start() let awaiters = nativeNode.start()
await pubsub.start()
let nativePeer = nativeNode.peerInfo let nativePeer = nativeNode.peerInfo
let peer = NativePeerInfo.init( let peer = NativePeerInfo.init(
daemonPeer.peer, daemonPeer.peer,
daemonPeer.addresses) daemonPeer.addresses)
await nativeNode.connect(peer) await nativeNode.connect(peer)
let subscribeHandle = nativeNode.subscribePeer(peer) pubsub.subscribePeer(peer.peerId)
await sleepAsync(1.seconds) await sleepAsync(1.seconds)
await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs)
@ -162,21 +180,21 @@ proc testPubSubNodePublish(gossip: bool = false,
discard await daemonNode.pubsubSubscribe(testTopic, pubsubHandler) discard await daemonNode.pubsubSubscribe(testTopic, pubsubHandler)
proc nativeHandler(topic: string, data: seq[byte]) {.async.} = discard proc nativeHandler(topic: string, data: seq[byte]) {.async.} = discard
await nativeNode.subscribe(testTopic, nativeHandler) await pubsub.subscribe(testTopic, nativeHandler)
await sleepAsync(5.seconds) await sleepAsync(5.seconds)
proc publisher() {.async.} = proc publisher() {.async.} =
while not finished: while not finished:
discard await nativeNode.publish(testTopic, msgData) discard await pubsub.publish(testTopic, msgData)
await sleepAsync(500.millis) await sleepAsync(500.millis)
await wait(publisher(), 5.minutes) # should be plenty of time await wait(publisher(), 5.minutes) # should be plenty of time
result = finished result = finished
await nativeNode.stop() await nativeNode.stop()
await pubsub.stop()
await allFutures(awaiters) await allFutures(awaiters)
await daemonNode.close() await daemonNode.close()
await subscribeHandle
suite "Interop": suite "Interop":
# TODO: chronos transports are leaking, # TODO: chronos transports are leaking,

tests/testminasn1.nim (new file, 214 lines)
View File

@ -0,0 +1,214 @@
## Nim-Libp2p
## Copyright (c) 2018 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import unittest
import ../libp2p/crypto/minasn1
import nimcrypto/utils as ncrutils
when defined(nimHasUsed): {.used.}
const Asn1EdgeValues = [
0'u64, (1'u64 shl 7) - 1'u64,
(1'u64 shl 7), (1'u64 shl 8) - 1'u64,
(1'u64 shl 8), (1'u64 shl 16) - 1'u64,
(1'u64 shl 16), (1'u64 shl 24) - 1'u64,
(1'u64 shl 24), (1'u64 shl 32) - 1'u64,
(1'u64 shl 32), (1'u64 shl 40) - 1'u64,
(1'u64 shl 40), (1'u64 shl 48) - 1'u64,
(1'u64 shl 48), (1'u64 shl 56) - 1'u64,
(1'u64 shl 56), 0xFFFF_FFFF_FFFF_FFFF'u64
]
const Asn1EdgeExpects = [
"00", "7F",
"8180", "81FF",
"820100", "82FFFF",
"83010000", "83FFFFFF",
"8401000000", "84FFFFFFFF",
"850100000000", "85FFFFFFFFFF",
"86010000000000", "86FFFFFFFFFFFF",
"8701000000000000", "87FFFFFFFFFFFFFF",
"880100000000000000", "88FFFFFFFFFFFFFFFF",
]
const Asn1UIntegerValues8 = [
0x00'u8, 0x7F'u8, 0x80'u8, 0xFF'u8,
]
const Asn1UIntegerExpects8 = [
"020100", "02017F", "02020080", "020200FF"
]
const Asn1UIntegerValues16 = [
0x00'u16, 0x7F'u16, 0x80'u16, 0xFF'u16,
0x7FFF'u16, 0x8000'u16, 0xFFFF'u16
]
const Asn1UIntegerExpects16 = [
"020100", "02017F", "02020080", "020200FF", "02027FFF",
"0203008000", "020300FFFF"
]
const Asn1UIntegerValues32 = [
0x00'u32, 0x7F'u32, 0x80'u32, 0xFF'u32,
0x7FFF'u32, 0x8000'u32, 0xFFFF'u32,
0x7FFF_FFFF'u32, 0x8000_0000'u32, 0xFFFF_FFFF'u32
]
const Asn1UIntegerExpects32 = [
"020100", "02017F", "02020080", "020200FF", "02027FFF",
"0203008000", "020300FFFF", "02047FFFFFFF", "02050080000000",
"020500FFFFFFFF"
]
const Asn1UIntegerValues64 = [
0x00'u64, 0x7F'u64, 0x80'u64, 0xFF'u64,
0x7FFF'u64, 0x8000'u64, 0xFFFF'u64,
0x7FFF_FFFF'u64, 0x8000_0000'u64, 0xFFFF_FFFF'u64,
0x7FFF_FFFF_FFFF_FFFF'u64, 0x8000_0000_0000_0000'u64,
0xFFFF_FFFF_FFFF_FFFF'u64
]
const Asn1UIntegerExpects64 = [
"020100", "02017F", "02020080", "020200FF", "02027FFF",
"0203008000", "020300FFFF", "02047FFFFFFF", "02050080000000",
"020500FFFFFFFF", "02087FFFFFFFFFFFFFFF", "0209008000000000000000",
"020900FFFFFFFFFFFFFFFF"
]
suite "Minimal ASN.1 encode/decode suite":
test "Length encoding edge values":
var empty = newSeq[byte](0)
for i in 0 ..< len(Asn1EdgeValues):
var value = newSeq[byte](9)
let r1 = asn1EncodeLength(empty, Asn1EdgeValues[i])
let r2 = asn1EncodeLength(value, Asn1EdgeValues[i])
value.setLen(r2)
check:
r1 == (len(Asn1EdgeExpects[i]) shr 1)
r2 == (len(Asn1EdgeExpects[i]) shr 1)
check:
ncrutils.fromHex(Asn1EdgeExpects[i]) == value
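  # Illustrative sketch of the DER length rule the edge values above exercise
  # (the proc below is ad hoc, not the module's asn1EncodeLength, which writes
  # into a caller-supplied buffer and returns the byte count): lengths below
  # 0x80 use the short form, a single byte; anything larger uses the long
  # form, 0x80 or'ed with the number of big-endian length bytes that follow,
  # which is why 128 encodes as 81 80 and 256 as 82 01 00.
  proc derEncodeLength(length: uint64): seq[byte] =
    if length < 0x80'u64:
      result = @[byte(length)]
    else:
      var be: seq[byte]
      var v = length
      while v > 0'u64:
        be.insert(byte(v and 0xFF'u64), 0)
        v = v shr 8
      result = @[byte(0x80 or len(be))] & be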
test "ASN.1 DER INTEGER encoding/decoding of native unsigned values test":
proc decodeBuffer(data: openarray[byte]): uint64 =
var ab = Asn1Buffer.init(data)
let fres = ab.read()
doAssert(fres.isOk() and fres.get().kind == Asn1Tag.Integer)
fres.get().vint
proc encodeInteger[T](value: T): seq[byte] =
var buffer = newSeq[byte](16)
let res = asn1EncodeInteger(buffer, value)
buffer.setLen(res)
buffer
for i in 0 ..< len(Asn1UIntegerValues8):
let buffer = encodeInteger(Asn1UIntegerValues8[i])
check:
toHex(buffer) == Asn1UIntegerExpects8[i]
decodeBuffer(buffer) == uint64(Asn1UIntegerValues8[i])
for i in 0 ..< len(Asn1UIntegerValues16):
let buffer = encodeInteger(Asn1UIntegerValues16[i])
check:
toHex(buffer) == Asn1UIntegerExpects16[i]
decodeBuffer(buffer) == uint64(Asn1UIntegerValues16[i])
for i in 0 ..< len(Asn1UIntegerValues32):
let buffer = encodeInteger(Asn1UIntegerValues32[i])
check:
toHex(buffer) == Asn1UIntegerExpects32[i]
decodeBuffer(buffer) == uint64(Asn1UIntegerValues32[i])
for i in 0 ..< len(Asn1UIntegerValues64):
let buffer = encodeInteger(Asn1UIntegerValues64[i])
check:
toHex(buffer) == Asn1UIntegerExpects64[i]
decodeBuffer(buffer) == uint64(Asn1UIntegerValues64[i])
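  # Illustrative sketch of the INTEGER rule behind the expectations above
  # (the proc name is ad hoc, not the module's asn1EncodeInteger): tag 02,
  # a length byte, then the minimal big-endian value with a 00 prepended
  # whenever the high bit is set, so the encoding stays non-negative --
  # hence 0x80 becomes 02 02 00 80 while 0x7F stays 02 01 7F.
  proc derEncodeUInt(value: uint64): seq[byte] =
    var payload: seq[byte]
    var v = value
    while v > 0'u64:
      payload.insert(byte(v and 0xFF'u64), 0)
      v = v shr 8
    if len(payload) == 0:
      payload = @[0x00'u8]               # zero is a single 00 content byte
    elif (payload[0] and 0x80'u8) != 0'u8:
      payload.insert(0x00'u8, 0)         # keep the value non-negative
    result = @[0x02'u8, byte(len(payload))] & payload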
test "ASN.1 DER INTEGER incorrect values decoding test":
proc decodeBuffer(data: string): Asn1Result[Asn1Field] =
var ab = Asn1Buffer.init(fromHex(data))
ab.read()
check:
decodeBuffer("0200").error == Asn1Error.Incorrect
decodeBuffer("0201").error == Asn1Error.Incomplete
decodeBuffer("02020000").error == Asn1Error.Incorrect
decodeBuffer("0203000001").error == Asn1Error.Incorrect
test "ASN.1 DER BITSTRING encoding/decoding with unused bits test":
proc encodeBits(value: string, bitsUsed: int): seq[byte] =
var buffer = newSeq[byte](16)
let res = asn1EncodeBitString(buffer, fromHex(value), bitsUsed)
buffer.setLen(res)
buffer
proc decodeBuffer(data: string): Asn1Field =
var ab = Asn1Buffer.init(fromHex(data))
let fres = ab.read()
doAssert(fres.isOk() and fres.get().kind == Asn1Tag.BitString)
fres.get()
check:
toHex(encodeBits("FF", 7)) == "03020780"
toHex(encodeBits("FF", 6)) == "030206C0"
toHex(encodeBits("FF", 5)) == "030205E0"
toHex(encodeBits("FF", 4)) == "030204F0"
toHex(encodeBits("FF", 3)) == "030203F8"
toHex(encodeBits("FF", 2)) == "030202FC"
toHex(encodeBits("FF", 1)) == "030201FE"
toHex(encodeBits("FF", 0)) == "030200FF"
let f0 = decodeBuffer("030200FF")
let f0b = @(f0.buffer.toOpenArray(f0.offset, f0.offset + f0.length - 1))
let f1 = decodeBuffer("030201FE")
let f1b = @(f1.buffer.toOpenArray(f1.offset, f1.offset + f1.length - 1))
let f2 = decodeBuffer("030202FC")
let f2b = @(f2.buffer.toOpenArray(f2.offset, f2.offset + f2.length - 1))
let f3 = decodeBuffer("030203F8")
let f3b = @(f3.buffer.toOpenArray(f3.offset, f3.offset + f3.length - 1))
let f4 = decodeBuffer("030204F0")
let f4b = @(f4.buffer.toOpenArray(f4.offset, f4.offset + f4.length - 1))
let f5 = decodeBuffer("030205E0")
let f5b = @(f5.buffer.toOpenArray(f5.offset, f5.offset + f5.length - 1))
let f6 = decodeBuffer("030206C0")
let f6b = @(f6.buffer.toOpenArray(f6.offset, f6.offset + f6.length - 1))
let f7 = decodeBuffer("03020780")
let f7b = @(f7.buffer.toOpenArray(f7.offset, f7.offset + f7.length - 1))
check:
f0.ubits == 0
toHex(f0b) == "FF"
f1.ubits == 1
toHex(f1b) == "FE"
f2.ubits == 2
toHex(f2b) == "FC"
f3.ubits == 3
toHex(f3b) == "F8"
f4.ubits == 4
toHex(f4b) == "F0"
f5.ubits == 5
toHex(f5b) == "E0"
f6.ubits == 6
toHex(f6b) == "C0"
f7.ubits == 7
toHex(f7b) == "80"
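  # Illustrative sketch of the BIT STRING layout these checks rely on (ad hoc
  # proc, not the module's asn1EncodeBitString): tag 03, a length byte, one
  # byte holding the count of unused bits in the final octet, then the data
  # with those unused low-order bits cleared -- so "FF" with the helper's
  # third argument set to 7 comes out as 03 02 07 80, and the decoder reports
  # ubits == 7, as checked above.
  proc derEncodeBitString(data: seq[byte], unusedBits: int): seq[byte] =
    var content = @[byte(unusedBits)] & data
    if unusedBits > 0 and len(content) > 1:
      # DER requires the unused trailing bits to be zero
      content[^1] = content[^1] and not byte((1 shl unusedBits) - 1)
    result = @[0x03'u8, byte(len(content))] & content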
test "ASN.1 DER BITSTRING incorrect values decoding test":
proc decodeBuffer(data: string): Asn1Result[Asn1Field] =
var ab = Asn1Buffer.init(fromHex(data))
ab.read()
check:
decodeBuffer("0300").error == Asn1Error.Incorrect
decodeBuffer("030180").error == Asn1Error.Incorrect
decodeBuffer("030107").error == Asn1Error.Incorrect
decodeBuffer("030200").error == Asn1Error.Incomplete
decodeBuffer("030208FF").error == Asn1Error.Incorrect
@ -135,18 +135,20 @@ suite "Mplex":
let let
conn = newBufferStream( conn = newBufferStream(
proc (data: seq[byte]) {.gcsafe, async.} = proc (data: seq[byte]) {.gcsafe, async.} =
discard discard,
timeout = 5.minutes
) )
chann = LPChannel.init(1, conn, true) chann = LPChannel.init(1, conn, true)
await chann.pushTo(("Hello!").toBytes) await chann.pushTo(("Hello!").toBytes)
let closeFut = chann.closeRemote()
var data = newSeq[byte](6) var data = newSeq[byte](6)
await chann.readExactly(addr data[0], 6) # this should work, since there is data in the buffer await chann.readExactly(addr data[0], 3)
let closeFut = chann.closeRemote() # closing channel
let readFut = chann.readExactly(addr data[3], 3)
await all(closeFut, readFut)
try: try:
await chann.readExactly(addr data[0], 6) # this should throw await chann.readExactly(addr data[0], 6) # this should fail now
await closeFut
except LPStreamEOFError: except LPStreamEOFError:
result = true result = true
finally: finally:
@ -156,6 +158,29 @@ suite "Mplex":
check: check:
waitFor(testClosedForRead()) == true waitFor(testClosedForRead()) == true
test "half closed - channel should allow writting on remote close":
proc testClosedForRead(): Future[bool] {.async.} =
let
testData = "Hello!".toBytes
conn = newBufferStream(
proc (data: seq[byte]) {.gcsafe, async.} =
discard
, timeout = 5.minutes
)
chann = LPChannel.init(1, conn, true)
var data = newSeq[byte](6)
await chann.closeRemote() # closing channel
try:
await chann.writeLp(testData)
return true
finally:
await chann.close()
await conn.close()
check:
waitFor(testClosedForRead()) == true
test "should not allow pushing data to channel when remote end closed": test "should not allow pushing data to channel when remote end closed":
proc testResetWrite(): Future[bool] {.async.} = proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
@ -211,20 +236,20 @@ suite "Mplex":
check: check:
waitFor(testResetWrite()) == true waitFor(testResetWrite()) == true
test "reset - channel should reset on timeout": test "reset - channel should reset on timeout":
proc testResetWrite(): Future[bool] {.async.} = proc testResetWrite(): Future[bool] {.async.} =
proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard
let let
conn = newBufferStream(writeHandler) conn = newBufferStream(writeHandler)
chann = LPChannel.init( chann = LPChannel.init(
1, conn, true, timeout = 100.millis) 1, conn, true, timeout = 100.millis)
await chann.closeEvent.wait() await chann.closeEvent.wait()
await conn.close() await conn.close()
result = true result = true
check: check:
waitFor(testResetWrite()) waitFor(testResetWrite())
test "e2e - read/write receiver": test "e2e - read/write receiver":
proc testNewStream() {.async.} = proc testNewStream() {.async.} =
@ -318,17 +343,23 @@ suite "Mplex":
bigseq.add(uint8(rand(uint('A')..uint('z')))) bigseq.add(uint8(rand(uint('A')..uint('z'))))
proc connHandler(conn: Connection) {.async, gcsafe.} = proc connHandler(conn: Connection) {.async, gcsafe.} =
let mplexListen = Mplex.init(conn) try:
mplexListen.streamHandler = proc(stream: Connection) let mplexListen = Mplex.init(conn)
{.async, gcsafe.} = mplexListen.streamHandler = proc(stream: Connection)
let msg = await stream.readLp(MaxMsgSize) {.async, gcsafe.} =
check msg == bigseq let msg = await stream.readLp(MaxMsgSize)
trace "Bigseq check passed!" check msg == bigseq
await stream.close() trace "Bigseq check passed!"
listenJob.complete() await stream.close()
listenJob.complete()
await mplexListen.handle() await mplexListen.handle()
await mplexListen.close() await sleepAsync(1.seconds) # give chronos some slack to process things
await mplexListen.close()
except CancelledError as exc:
raise exc
except CatchableError as exc:
check false
let transport1: TcpTransport = TcpTransport.init() let transport1: TcpTransport = TcpTransport.init()
let listenFut = await transport1.listen(ma, connHandler) let listenFut = await transport1.listen(ma, connHandler)
@ -2,7 +2,8 @@ import testvarint,
testminprotobuf, testminprotobuf,
teststreamseq teststreamseq
import testrsa, import testminasn1,
testrsa,
testecnist, testecnist,
tested25519, tested25519,
testsecp256k1, testsecp256k1,
@ -118,7 +118,7 @@ suite "Switch":
# plus 4 for the pubsub streams # plus 4 for the pubsub streams
check (BufferStreamTracker(bufferTracker).opened == check (BufferStreamTracker(bufferTracker).opened ==
(BufferStreamTracker(bufferTracker).closed + 4.uint64)) (BufferStreamTracker(bufferTracker).closed))
var connTracker = getTracker(ConnectionTrackerName) var connTracker = getTracker(ConnectionTrackerName)
# echo connTracker.dump() # echo connTracker.dump()
@ -127,7 +127,7 @@ suite "Switch":
# and the pubsub streams that won't clean up until # and the pubsub streams that won't clean up until
# `disconnect()` or `stop()` # `disconnect()` or `stop()`
check (ConnectionTracker(connTracker).opened == check (ConnectionTracker(connTracker).opened ==
(ConnectionTracker(connTracker).closed + 8.uint64)) (ConnectionTracker(connTracker).closed + 4.uint64))
await allFuturesThrowing( await allFuturesThrowing(
done.wait(5.seconds), done.wait(5.seconds),