diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 45aaca357..3edd2ff00 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -121,7 +121,7 @@ proc onClose(c: ConnManager, conn: Connection) {.async.} = ## triggers the connections resource cleanup ## - await conn.closeEvent.wait() + await conn.join() trace "triggering connection cleanup" await c.cleanupConn(conn) diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 820bde77d..1a448c315 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -70,13 +70,15 @@ when supported(PKScheme.Secp256k1): import ecnist, bearssl import ../protobuf/minprotobuf, ../vbuffer, ../multihash, ../multicodec -import nimcrypto/[rijndael, twofish, sha2, hash, hmac, utils] +import nimcrypto/[rijndael, twofish, sha2, hash, hmac] +# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures. +import nimcrypto/utils as ncrutils import ../utility import stew/results export results # This is workaround for Nim's `import` bug -export rijndael, twofish, sha2, hash, hmac, utils +export rijndael, twofish, sha2, hash, hmac, ncrutils from strutils import split @@ -514,20 +516,14 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: string): bool = ## hexadecimal string representation. ## ## Returns ``true`` on success. - try: - key.init(utils.fromHex(data)) - except ValueError: - false + key.init(ncrutils.fromHex(data)) proc init*(sig: var Signature, data: string): bool = ## Initialize signature ``sig`` from serialized hexadecimal string ## representation. ## ## Returns ``true`` on success. - try: - sig.init(utils.fromHex(data)) - except ValueError: - false + sig.init(ncrutils.fromHex(data)) proc init*(t: typedesc[PrivateKey], data: openarray[byte]): CryptoResult[PrivateKey] = @@ -559,10 +555,7 @@ proc init*(t: typedesc[Signature], proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] = ## Create new private key from libp2p's protobuf serialized hexadecimal string ## form. - try: - t.init(utils.fromHex(data)) - except ValueError: - err(KeyError) + t.init(ncrutils.fromHex(data)) when supported(PKScheme.RSA): proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey = @@ -591,17 +584,11 @@ when supported(PKScheme.ECDSA): proc init*(t: typedesc[PublicKey], data: string): CryptoResult[PublicKey] = ## Create new public key from libp2p's protobuf serialized hexadecimal string ## form. - try: - t.init(utils.fromHex(data)) - except ValueError: - err(KeyError) + t.init(ncrutils.fromHex(data)) proc init*(t: typedesc[Signature], data: string): CryptoResult[Signature] = ## Create new signature from serialized hexadecimal string form. - try: - t.init(utils.fromHex(data)) - except ValueError: - err(SigError) + t.init(ncrutils.fromHex(data)) proc `==`*(key1, key2: PublicKey): bool {.inline.} = ## Return ``true`` if two public keys ``key1`` and ``key2`` of the same @@ -709,7 +696,7 @@ func shortLog*(key: PrivateKey|PublicKey): string = proc `$`*(sig: Signature): string = ## Get string representation of signature ``sig``. 
- result = toHex(sig.data) + result = ncrutils.toHex(sig.data) proc sign*(key: PrivateKey, data: openarray[byte]): CryptoResult[Signature] {.gcsafe.} = diff --git a/libp2p/crypto/ecnist.nim b/libp2p/crypto/ecnist.nim index ad8312613..8c3ebdce5 100644 --- a/libp2p/crypto/ecnist.nim +++ b/libp2p/crypto/ecnist.nim @@ -17,7 +17,8 @@ {.push raises: [Defect].} import bearssl -import nimcrypto/utils +# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures. +import nimcrypto/utils as ncrutils import minasn1 export minasn1.Asn1Error import stew/[results, ctops] @@ -289,7 +290,7 @@ proc `$`*(seckey: EcPrivateKey): string = result = "Corrupted key" else: let e = offset + cast[int](seckey.key.xlen) - 1 - result = toHex(seckey.buffer.toOpenArray(offset, e)) + result = ncrutils.toHex(seckey.buffer.toOpenArray(offset, e)) proc `$`*(pubkey: EcPublicKey): string = ## Return string representation of EC public key. @@ -305,14 +306,14 @@ proc `$`*(pubkey: EcPublicKey): string = result = "Corrupted key" else: let e = offset + cast[int](pubkey.key.qlen) - 1 - result = toHex(pubkey.buffer.toOpenArray(offset, e)) + result = ncrutils.toHex(pubkey.buffer.toOpenArray(offset, e)) proc `$`*(sig: EcSignature): string = ## Return hexadecimal string representation of EC signature. if isNil(sig) or len(sig.buffer) == 0: result = "Empty or uninitialized ECNIST signature" else: - result = toHex(sig.buffer) + result = ncrutils.toHex(sig.buffer) proc toRawBytes*(seckey: EcPrivateKey, data: var openarray[byte]): EcResult[int] = ## Serialize EC private key ``seckey`` to raw binary form and store it @@ -708,14 +709,16 @@ proc init*(sig: var EcSignature, data: openarray[byte]): Result[void, Asn1Error] else: err(Asn1Error.Incorrect) -proc init*[T: EcPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} = +proc init*[T: EcPKI](sospk: var T, + data: string): Result[void, Asn1Error] {.inline.} = ## Initialize EC `private key`, `public key` or `signature` ``sospk`` from ## ASN.1 DER hexadecimal string representation ``data``. ## ## Procedure returns ``Asn1Status``. - sospk.init(fromHex(data)) + sospk.init(ncrutils.fromHex(data)) -proc init*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivateKey] = +proc init*(t: typedesc[EcPrivateKey], + data: openarray[byte]): EcResult[EcPrivateKey] = ## Initialize EC private key from ASN.1 DER binary representation ``data`` and ## return constructed object. var key: EcPrivateKey @@ -725,7 +728,8 @@ proc init*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivate else: ok(key) -proc init*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKey] = +proc init*(t: typedesc[EcPublicKey], + data: openarray[byte]): EcResult[EcPublicKey] = ## Initialize EC public key from ASN.1 DER binary representation ``data`` and ## return constructed object. var key: EcPublicKey @@ -735,7 +739,8 @@ proc init*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKe else: ok(key) -proc init*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignature] = +proc init*(t: typedesc[EcSignature], + data: openarray[byte]): EcResult[EcSignature] = ## Initialize EC signature from raw binary representation ``data`` and ## return constructed object. 
var sig: EcSignature @@ -748,10 +753,7 @@ proc init*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignatur proc init*[T: EcPKI](t: typedesc[T], data: string): EcResult[T] = ## Initialize EC `private key`, `public key` or `signature` from hexadecimal ## string representation ``data`` and return constructed object. - try: - t.init(fromHex(data)) - except ValueError: - err(EcKeyIncorrectError) + t.init(ncrutils.fromHex(data)) proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool = ## Initialize EC `private key` or `scalar` ``key`` from raw binary @@ -833,9 +835,10 @@ proc initRaw*[T: EcPKI](sospk: var T, data: string): bool {.inline.} = ## raw hexadecimal string representation ``data``. ## ## Procedure returns ``true`` on success, ``false`` otherwise. - result = sospk.initRaw(fromHex(data)) + result = sospk.initRaw(ncrutils.fromHex(data)) -proc initRaw*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPrivateKey] = +proc initRaw*(t: typedesc[EcPrivateKey], + data: openarray[byte]): EcResult[EcPrivateKey] = ## Initialize EC private key from raw binary representation ``data`` and ## return constructed object. var res: EcPrivateKey @@ -844,7 +847,8 @@ proc initRaw*(t: typedesc[EcPrivateKey], data: openarray[byte]): EcResult[EcPriv else: ok(res) -proc initRaw*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPublicKey] = +proc initRaw*(t: typedesc[EcPublicKey], + data: openarray[byte]): EcResult[EcPublicKey] = ## Initialize EC public key from raw binary representation ``data`` and ## return constructed object. var res: EcPublicKey @@ -853,7 +857,8 @@ proc initRaw*(t: typedesc[EcPublicKey], data: openarray[byte]): EcResult[EcPubli else: ok(res) -proc initRaw*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSignature] = +proc initRaw*(t: typedesc[EcSignature], + data: openarray[byte]): EcResult[EcSignature] = ## Initialize EC signature from raw binary representation ``data`` and ## return constructed object. var res: EcSignature @@ -865,7 +870,7 @@ proc initRaw*(t: typedesc[EcSignature], data: openarray[byte]): EcResult[EcSigna proc initRaw*[T: EcPKI](t: typedesc[T], data: string): T {.inline.} = ## Initialize EC `private key`, `public key` or `signature` from raw ## hexadecimal string representation ``data`` and return constructed object. - result = t.initRaw(fromHex(data)) + result = t.initRaw(ncrutils.fromHex(data)) proc scalarMul*(pub: EcPublicKey, sec: EcPrivateKey): EcPublicKey = ## Return scalar multiplication of ``pub`` and ``sec``. @@ -926,7 +931,7 @@ proc getSecret*(pubkey: EcPublicKey, seckey: EcPrivateKey): seq[byte] = copyMem(addr result[0], addr data[0], res) proc sign*[T: byte|char](seckey: EcPrivateKey, - message: openarray[T]): EcResult[EcSignature] {.gcsafe.} = + message: openarray[T]): EcResult[EcSignature] {.gcsafe.} = ## Get ECDSA signature of data ``message`` using private key ``seckey``. if isNil(seckey): return err(EcKeyIncorrectError) diff --git a/libp2p/crypto/ed25519/ed25519.nim b/libp2p/crypto/ed25519/ed25519.nim index 3c16b7cf6..663cee97f 100644 --- a/libp2p/crypto/ed25519/ed25519.nim +++ b/libp2p/crypto/ed25519/ed25519.nim @@ -14,7 +14,9 @@ {.push raises: Defect.} import constants, bearssl -import nimcrypto/[hash, sha2, utils] +import nimcrypto/[hash, sha2] +# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures. 
+import nimcrypto/utils as ncrutils import stew/[results, ctops] export results @@ -1735,14 +1737,17 @@ proc `==`*(eda, edb: EdSignature): bool = ## Compare ED25519 `signature` objects for equality. result = CT.isEqual(eda.data, edb.data) -proc `$`*(key: EdPrivateKey): string = toHex(key.data) +proc `$`*(key: EdPrivateKey): string = ## Return string representation of ED25519 `private key`. + ncrutils.toHex(key.data) -proc `$`*(key: EdPublicKey): string = toHex(key.data) +proc `$`*(key: EdPublicKey): string = ## Return string representation of ED25519 `private key`. + ncrutils.toHex(key.data) -proc `$`*(sig: EdSignature): string = toHex(sig.data) +proc `$`*(sig: EdSignature): string = ## Return string representation of ED25519 `signature`. + ncrutils.toHex(sig.data) proc init*(key: var EdPrivateKey, data: openarray[byte]): bool = ## Initialize ED25519 `private key` ``key`` from raw binary @@ -1779,32 +1784,24 @@ proc init*(key: var EdPrivateKey, data: string): bool = ## representation ``data``. ## ## Procedure returns ``true`` on success. - try: - init(key, fromHex(data)) - except ValueError: - false + init(key, ncrutils.fromHex(data)) proc init*(key: var EdPublicKey, data: string): bool = ## Initialize ED25519 `public key` ``key`` from hexadecimal string ## representation ``data``. ## ## Procedure returns ``true`` on success. - try: - init(key, fromHex(data)) - except ValueError: - false + init(key, ncrutils.fromHex(data)) proc init*(sig: var EdSignature, data: string): bool = ## Initialize ED25519 `signature` ``sig`` from hexadecimal string ## representation ``data``. ## ## Procedure returns ``true`` on success. - try: - init(sig, fromHex(data)) - except ValueError: - false + init(sig, ncrutils.fromHex(data)) -proc init*(t: typedesc[EdPrivateKey], data: openarray[byte]): Result[EdPrivateKey, EdError] = +proc init*(t: typedesc[EdPrivateKey], + data: openarray[byte]): Result[EdPrivateKey, EdError] = ## Initialize ED25519 `private key` from raw binary representation ``data`` ## and return constructed object. var res: t @@ -1813,7 +1810,8 @@ proc init*(t: typedesc[EdPrivateKey], data: openarray[byte]): Result[EdPrivateKe else: ok(res) -proc init*(t: typedesc[EdPublicKey], data: openarray[byte]): Result[EdPublicKey, EdError] = +proc init*(t: typedesc[EdPublicKey], + data: openarray[byte]): Result[EdPublicKey, EdError] = ## Initialize ED25519 `public key` from raw binary representation ``data`` ## and return constructed object. var res: t @@ -1822,7 +1820,8 @@ proc init*(t: typedesc[EdPublicKey], data: openarray[byte]): Result[EdPublicKey, else: ok(res) -proc init*(t: typedesc[EdSignature], data: openarray[byte]): Result[EdSignature, EdError] = +proc init*(t: typedesc[EdSignature], + data: openarray[byte]): Result[EdSignature, EdError] = ## Initialize ED25519 `signature` from raw binary representation ``data`` ## and return constructed object. var res: t @@ -1831,7 +1830,8 @@ proc init*(t: typedesc[EdSignature], data: openarray[byte]): Result[EdSignature, else: ok(res) -proc init*(t: typedesc[EdPrivateKey], data: string): Result[EdPrivateKey, EdError] = +proc init*(t: typedesc[EdPrivateKey], + data: string): Result[EdPrivateKey, EdError] = ## Initialize ED25519 `private key` from hexadecimal string representation ## ``data`` and return constructed object. 
var res: t @@ -1840,7 +1840,8 @@ proc init*(t: typedesc[EdPrivateKey], data: string): Result[EdPrivateKey, EdErro else: ok(res) -proc init*(t: typedesc[EdPublicKey], data: string): Result[EdPublicKey, EdError] = +proc init*(t: typedesc[EdPublicKey], + data: string): Result[EdPublicKey, EdError] = ## Initialize ED25519 `public key` from hexadecimal string representation ## ``data`` and return constructed object. var res: t @@ -1849,7 +1850,8 @@ proc init*(t: typedesc[EdPublicKey], data: string): Result[EdPublicKey, EdError] else: ok(res) -proc init*(t: typedesc[EdSignature], data: string): Result[EdSignature, EdError] = +proc init*(t: typedesc[EdSignature], + data: string): Result[EdSignature, EdError] = ## Initialize ED25519 `signature` from hexadecimal string representation ## ``data`` and return constructed object. var res: t diff --git a/libp2p/crypto/minasn1.nim b/libp2p/crypto/minasn1.nim index 9363d2118..fbfd33f73 100644 --- a/libp2p/crypto/minasn1.nim +++ b/libp2p/crypto/minasn1.nim @@ -11,9 +11,10 @@ {.push raises: [Defect].} -import stew/[endians2, results] +import stew/[endians2, results, ctops] export results -import nimcrypto/utils +# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures. +import nimcrypto/utils as ncrutils type Asn1Error* {.pure.} = enum @@ -122,7 +123,7 @@ proc len*[T: Asn1Buffer|Asn1Composite](abc: T): int {.inline.} = len(abc.buffer) - abc.offset proc len*(field: Asn1Field): int {.inline.} = - result = field.length + field.length template getPtr*(field: untyped): pointer = cast[pointer](unsafeAddr field.buffer[field.offset]) @@ -153,30 +154,32 @@ proc code*(tag: Asn1Tag): byte {.inline.} = of Asn1Tag.Context: 0xA0'u8 -proc asn1EncodeLength*(dest: var openarray[byte], length: int64): int = +proc asn1EncodeLength*(dest: var openarray[byte], length: uint64): int = ## Encode ASN.1 DER length part of TLV triple and return number of bytes ## (octets) used. ## ## If length of ``dest`` is less then number of required bytes to encode - ## ``length`` value, then result of encoding will not be stored in ``dest`` + ## ``length`` value, then result of encoding WILL NOT BE stored in ``dest`` ## but number of bytes (octets) required will be returned. - if length < 0x80: + if length < 0x80'u64: if len(dest) >= 1: - dest[0] = cast[byte](length) - result = 1 + dest[0] = byte(length and 0x7F'u64) + 1 else: - result = 0 + var res = 1'u64 var z = length while z != 0: - inc(result) + inc(res) z = z shr 8 - if len(dest) >= result + 1: - dest[0] = cast[byte](0x80 + result) + if uint64(len(dest)) >= res: + dest[0] = byte((0x80'u64 + (res - 1'u64)) and 0xFF) var o = 1 - for j in countdown(result - 1, 0): - dest[o] = cast[byte](length shr (j shl 3)) + for j in countdown(res - 2, 0): + dest[o] = byte((length shr (j shl 3)) and 0xFF'u64) inc(o) - inc(result) + # Because our `length` argument is `uint64`, `res` could not be bigger + # then 9, so it is safe to convert it to `int`. + int(res) proc asn1EncodeInteger*(dest: var openarray[byte], value: openarray[byte]): int = @@ -184,35 +187,46 @@ proc asn1EncodeInteger*(dest: var openarray[byte], ## and return number of bytes (octets) used. ## ## If length of ``dest`` is less then number of required bytes to encode - ## ``value``, then result of encoding will not be stored in ``dest`` + ## ``value``, then result of encoding WILL NOT BE stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var o = 0 var lenlen = 0 - for i in 0.. 
0: - if o == len(value): - dec(o) - if value[o] >= 0x80'u8: - lenlen = asn1EncodeLength(buffer, len(value) - o + 1) - result = 1 + lenlen + 1 + (len(value) - o) + + let offset = + block: + var o = 0 + for i in 0 ..< len(value): + if value[o] != 0x00: + break + inc(o) + if o < len(value): + o + else: + o - 1 + + let destlen = + if len(value) > 0: + if value[offset] >= 0x80'u8: + lenlen = asn1EncodeLength(buffer, uint64(len(value) - offset + 1)) + 1 + lenlen + 1 + (len(value) - offset) + else: + lenlen = asn1EncodeLength(buffer, uint64(len(value) - offset)) + 1 + lenlen + (len(value) - offset) else: - lenlen = asn1EncodeLength(buffer, len(value) - o) - result = 1 + lenlen + (len(value) - o) - else: - result = 2 - if len(dest) >= result: - var s = 1 + 2 + + if len(dest) >= destlen: + var shift = 1 dest[0] = Asn1Tag.Integer.code() copyMem(addr dest[1], addr buffer[0], lenlen) - if value[o] >= 0x80'u8: - dest[1 + lenlen] = 0x00'u8 - s = 2 - if len(value) > 0: - copyMem(addr dest[s + lenlen], unsafeAddr value[o], len(value) - o) + # If ``destlen > 2`` it means that ``len(value) > 0`` too. + if destlen > 2: + if value[offset] >= 0x80'u8: + dest[1 + lenlen] = 0x00'u8 + shift = 2 + copyMem(addr dest[shift + lenlen], unsafeAddr value[offset], + len(value) - offset) + destlen proc asn1EncodeInteger*[T: SomeUnsignedInt](dest: var openarray[byte], value: T): int = @@ -231,11 +245,12 @@ proc asn1EncodeBoolean*(dest: var openarray[byte], value: bool): int = ## If length of ``dest`` is less then number of required bytes to encode ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. - result = 3 - if len(dest) >= result: + let res = 3 + if len(dest) >= res: dest[0] = Asn1Tag.Boolean.code() dest[1] = 0x01'u8 dest[2] = if value: 0xFF'u8 else: 0x00'u8 + res proc asn1EncodeNull*(dest: var openarray[byte]): int = ## Encode ASN.1 DER `NULL` and return number of bytes (octets) used. @@ -243,13 +258,14 @@ proc asn1EncodeNull*(dest: var openarray[byte]): int = ## If length of ``dest`` is less then number of required bytes to encode ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. - result = 2 - if len(dest) >= result: + let res = 2 + if len(dest) >= res: dest[0] = Asn1Tag.Null.code() dest[1] = 0x00'u8 + res proc asn1EncodeOctetString*(dest: var openarray[byte], - value: openarray[byte]): int = + value: openarray[byte]): int = ## Encode array of bytes as ASN.1 DER `OCTET STRING` and return number of ## bytes (octets) used. ## @@ -257,38 +273,50 @@ proc asn1EncodeOctetString*(dest: var openarray[byte], ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var lenlen = asn1EncodeLength(buffer, len(value)) - result = 1 + lenlen + len(value) - if len(dest) >= result: + let lenlen = asn1EncodeLength(buffer, uint64(len(value))) + let res = 1 + lenlen + len(value) + if len(dest) >= res: dest[0] = Asn1Tag.OctetString.code() copyMem(addr dest[1], addr buffer[0], lenlen) if len(value) > 0: copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value)) + res proc asn1EncodeBitString*(dest: var openarray[byte], value: openarray[byte], bits = 0): int = ## Encode array of bytes as ASN.1 DER `BIT STRING` and return number of bytes ## (octets) used. ## - ## ``bits`` number of used bits in ``value``. 
If ``bits == 0``, all the bits - ## from ``value`` are used, if ``bits != 0`` only number of ``bits`` will be - ## used. + ## ``bits`` number of unused bits in ``value``. If ``bits == 0``, all the bits + ## from ``value`` will be used. ## ## If length of ``dest`` is less then number of required bytes to encode ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var lenlen = asn1EncodeLength(buffer, len(value) + 1) - var lbits = 0 - if bits != 0: - lbits = len(value) shl 3 - bits - result = 1 + lenlen + 1 + len(value) - if len(dest) >= result: + let bitlen = + if bits != 0: + (len(value) shl 3) - bits + else: + (len(value) shl 3) + + # Number of bytes used + let bytelen = (bitlen + 7) shr 3 + # Number of unused bits + let unused = (8 - (bitlen and 7)) and 7 + let mask = not((1'u8 shl unused) - 1'u8) + var lenlen = asn1EncodeLength(buffer, uint64(bytelen + 1)) + let res = 1 + lenlen + 1 + len(value) + if len(dest) >= res: dest[0] = Asn1Tag.BitString.code() copyMem(addr dest[1], addr buffer[0], lenlen) - dest[1 + lenlen] = cast[byte](lbits) - if len(value) > 0: - copyMem(addr dest[2 + lenlen], unsafeAddr value[0], len(value)) + dest[1 + lenlen] = byte(unused) + if bytelen > 0: + let lastbyte = value[bytelen - 1] + copyMem(addr dest[2 + lenlen], unsafeAddr value[0], bytelen) + # Set unused bits to zero + dest[2 + lenlen + bytelen - 1] = lastbyte and mask + res proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openarray[byte], value: T): int = @@ -296,53 +324,48 @@ proc asn1EncodeTag[T: SomeUnsignedInt](dest: var openarray[byte], if value <= cast[T](0x7F): if len(dest) >= 1: dest[0] = cast[byte](value) - result = 1 + 1 else: var s = 0 + var res = 0 while v != 0: v = v shr 7 s += 7 - inc(result) - if len(dest) >= result: + inc(res) + if len(dest) >= res: var k = 0 while s != 0: s -= 7 dest[k] = cast[byte](((value shr s) and cast[T](0x7F)) or cast[T](0x80)) inc(k) dest[k - 1] = dest[k - 1] and 0x7F'u8 + res proc asn1EncodeOid*(dest: var openarray[byte], value: openarray[int]): int = ## Encode array of integers ``value`` as ASN.1 DER `OBJECT IDENTIFIER` and ## return number of bytes (octets) used. ## - ## OBJECT IDENTIFIER requirements for ``value`` elements: - ## * len(value) >= 2 - ## * value[0] >= 1 and value[0] < 2 - ## * value[1] >= 1 and value[1] < 39 - ## ## If length of ``dest`` is less then number of required bytes to encode ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. 
var buffer: array[16, byte] - result = 1 - doAssert(len(value) >= 2) - doAssert(value[0] >= 1 and value[0] < 2) - doAssert(value[1] >= 1 and value[1] <= 39) + var res = 1 var oidlen = 1 for i in 2..= result: + res += asn1EncodeLength(buffer, uint64(oidlen)) + res += oidlen + if len(dest) >= res: let last = dest.high var offset = 1 dest[0] = Asn1Tag.Oid.code() - offset += asn1EncodeLength(dest.toOpenArray(offset, last), oidlen) + offset += asn1EncodeLength(dest.toOpenArray(offset, last), uint64(oidlen)) dest[offset] = cast[byte](value[0] * 40 + value[1]) offset += 1 for i in 2..= result: + let lenlen = asn1EncodeLength(buffer, uint64(len(value))) + let res = 1 + lenlen + len(value) + if len(dest) >= res: dest[0] = Asn1Tag.Oid.code() copyMem(addr dest[1], addr buffer[0], lenlen) copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value)) + res proc asn1EncodeSequence*(dest: var openarray[byte], value: openarray[byte]): int = @@ -371,12 +395,13 @@ proc asn1EncodeSequence*(dest: var openarray[byte], ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var lenlen = asn1EncodeLength(buffer, len(value)) - result = 1 + lenlen + len(value) - if len(dest) >= result: + let lenlen = asn1EncodeLength(buffer, uint64(len(value))) + let res = 1 + lenlen + len(value) + if len(dest) >= res: dest[0] = Asn1Tag.Sequence.code() copyMem(addr dest[1], addr buffer[0], lenlen) copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value)) + res proc asn1EncodeComposite*(dest: var openarray[byte], value: Asn1Composite): int = @@ -386,29 +411,34 @@ proc asn1EncodeComposite*(dest: var openarray[byte], ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var lenlen = asn1EncodeLength(buffer, len(value.buffer)) - result = 1 + lenlen + len(value.buffer) - if len(dest) >= result: + let lenlen = asn1EncodeLength(buffer, uint64(len(value.buffer))) + let res = 1 + lenlen + len(value.buffer) + if len(dest) >= res: dest[0] = value.tag.code() copyMem(addr dest[1], addr buffer[0], lenlen) copyMem(addr dest[1 + lenlen], unsafeAddr value.buffer[0], len(value.buffer)) + res proc asn1EncodeContextTag*(dest: var openarray[byte], value: openarray[byte], tag: int): int = ## Encode ASN.1 DER `CONTEXT SPECIFIC TAG` ``tag`` for value ``value`` and ## return number of bytes (octets) used. ## + ## Note: Only values in [0, 15] range can be used as context tag ``tag`` + ## values. + ## ## If length of ``dest`` is less then number of required bytes to encode ## ``value``, then result of encoding will not be stored in ``dest`` ## but number of bytes (octets) required will be returned. var buffer: array[16, byte] - var lenlen = asn1EncodeLength(buffer, len(value)) - result = 1 + lenlen + len(value) - if len(dest) >= result: - dest[0] = 0xA0'u8 or (cast[byte](tag) and 0x0F) + let lenlen = asn1EncodeLength(buffer, uint64(len(value))) + let res = 1 + lenlen + len(value) + if len(dest) >= res: + dest[0] = 0xA0'u8 or (byte(tag and 0xFF) and 0x0F'u8) copyMem(addr dest[1], addr buffer[0], lenlen) copyMem(addr dest[1 + lenlen], unsafeAddr value[0], len(value)) + res proc getLength(ab: var Asn1Buffer): Asn1Result[uint64] = ## Decode length part of ASN.1 TLV triplet. 
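# A minimal standalone sketch of the ASN.1 DER length encoding that the
# reworked `asn1EncodeLength` above performs. `derLengthBytes` is a
# hypothetical helper used only for illustration: lengths below 0x80 use the
# short form (a single octet), while larger values use the long form of
# 0x80 + octet count followed by the big-endian octets of the length.
proc derLengthBytes(length: uint64): seq[byte] =
  if length < 0x80'u64:
    @[byte(length)]
  else:
    var octets: seq[byte]
    var z = length
    while z != 0:
      octets.insert(byte(z and 0xFF'u64), 0)
      z = z shr 8
    @[byte(0x80 + octets.len)] & octets
# For example, derLengthBytes(0x7F'u64) == @[0x7F'u8] and
# derLengthBytes(0x1234'u64) == @[0x82'u8, 0x12'u8, 0x34'u8].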
@@ -457,197 +487,300 @@ proc read*(ab: var Asn1Buffer): Asn1Result[Asn1Field] = field: Asn1Field tag, ttag, offset: int length, tlength: uint64 - klass: Asn1Class + aclass: Asn1Class inclass: bool inclass = false while true: offset = ab.offset - klass = ? ab.getTag(tag) + aclass = ? ab.getTag(tag) - if klass == Asn1Class.ContextSpecific: + case aclass + of Asn1Class.ContextSpecific: if inclass: return err(Asn1Error.Incorrect) - - inclass = true - ttag = tag - tlength = ? ab.getLength() - - elif klass == Asn1Class.Universal: + else: + inclass = true + ttag = tag + tlength = ? ab.getLength() + of Asn1Class.Universal: length = ? ab.getLength() if inclass: if length >= tlength: return err(Asn1Error.Incorrect) - if cast[byte](tag) == Asn1Tag.Boolean.code(): + case byte(tag) + of Asn1Tag.Boolean.code(): # BOOLEAN if length != 1: return err(Asn1Error.Incorrect) - if not ab.isEnough(cast[int](length)): + + if not ab.isEnough(int(length)): return err(Asn1Error.Incomplete) + let b = ab.buffer[ab.offset] if b != 0xFF'u8 and b != 0x00'u8: return err(Asn1Error.Incorrect) - field = Asn1Field(kind: Asn1Tag.Boolean, klass: klass, - index: ttag, offset: cast[int](ab.offset), + field = Asn1Field(kind: Asn1Tag.Boolean, klass: aclass, + index: ttag, offset: int(ab.offset), length: 1) shallowCopy(field.buffer, ab.buffer) field.vbool = (b == 0xFF'u8) ab.offset += 1 return ok(field) - elif cast[byte](tag) == Asn1Tag.Integer.code(): + + of Asn1Tag.Integer.code(): # INTEGER - if not ab.isEnough(cast[int](length)): - return err(Asn1Error.Incomplete) - if ab.buffer[ab.offset] == 0x00'u8: - length -= 1 - ab.offset += 1 - field = Asn1Field(kind: Asn1Tag.Integer, klass: klass, - index: ttag, offset: cast[int](ab.offset), - length: cast[int](length)) - shallowCopy(field.buffer, ab.buffer) - if length <= 8: - for i in 0.. 1: + return err(Asn1Error.Incorrect) + + if zc == 0: + # Negative or Positive integer + field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass, + index: ttag, offset: int(ab.offset), + length: int(length)) + shallowCopy(field.buffer, ab.buffer) + if (ab.buffer[ab.offset] and 0x80'u8) == 0x80'u8: + # Negative integer + if length <= 8: + # We need this transformation because our field.vint is uint64. + for i in 0 ..< 8: + if i < 8 - int(length): + field.vint = (field.vint shl 8) or 0xFF'u64 + else: + let offset = ab.offset + i - (8 - int(length)) + field.vint = (field.vint shl 8) or uint64(ab.buffer[offset]) + else: + # Positive integer + if length <= 8: + for i in 0 ..< int(length): + field.vint = (field.vint shl 8) or + uint64(ab.buffer[ab.offset + i]) + ab.offset += int(length) + return ok(field) + else: + if length == 1: + # Zero value integer + field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass, + index: ttag, offset: int(ab.offset), + length: int(length), vint: 0'u64) + shallowCopy(field.buffer, ab.buffer) + ab.offset += int(length) + return ok(field) + else: + # Positive integer with leading zero + field = Asn1Field(kind: Asn1Tag.Integer, klass: aclass, + index: ttag, offset: int(ab.offset) + 1, + length: int(length) - 1) + shallowCopy(field.buffer, ab.buffer) + if length <= 9: + for i in 1 ..< int(length): + field.vint = (field.vint shl 8) or + uint64(ab.buffer[ab.offset + i]) + ab.offset += int(length) + return ok(field) + + of Asn1Tag.BitString.code(): # BIT STRING - if not ab.isEnough(cast[int](length)): + if length == 0: + # BIT STRING should include `unused` bits field, so length should be + # bigger then 1. 
+ return err(Asn1Error.Incorrect) + + elif length == 1: + if ab.buffer[ab.offset] != 0x00'u8: + return err(Asn1Error.Incorrect) + else: + # Zero-length BIT STRING. + field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass, + index: ttag, offset: int(ab.offset + 1), + length: 0, ubits: 0) + shallowCopy(field.buffer, ab.buffer) + ab.offset += int(length) + return ok(field) + + else: + if not ab.isEnough(int(length)): + return err(Asn1Error.Incomplete) + + let unused = ab.buffer[ab.offset] + if unused > 0x07'u8: + # Number of unused bits should not be bigger then `7`. + return err(Asn1Error.Incorrect) + + let mask = (1'u8 shl int(unused)) - 1'u8 + if (ab.buffer[ab.offset + int(length) - 1] and mask) != 0x00'u8: + ## All unused bits should be set to `0`. + return err(Asn1Error.Incorrect) + + field = Asn1Field(kind: Asn1Tag.BitString, klass: aclass, + index: ttag, offset: int(ab.offset + 1), + length: int(length - 1), ubits: int(unused)) + shallowCopy(field.buffer, ab.buffer) + ab.offset += int(length) + return ok(field) + + of Asn1Tag.OctetString.code(): + # OCTET STRING + if not ab.isEnough(int(length)): return err(Asn1Error.Incomplete) - field = Asn1Field(kind: Asn1Tag.BitString, klass: klass, - index: ttag, offset: cast[int](ab.offset + 1), - length: cast[int](length - 1)) + + field = Asn1Field(kind: Asn1Tag.OctetString, klass: aclass, + index: ttag, offset: int(ab.offset), + length: int(length)) shallowCopy(field.buffer, ab.buffer) - field.ubits = cast[int](((length - 1) shl 3) - ab.buffer[ab.offset]) - ab.offset += cast[int](length) + ab.offset += int(length) return ok(field) - elif cast[byte](tag) == Asn1Tag.OctetString.code(): - # OCT STRING - if not ab.isEnough(cast[int](length)): - return err(Asn1Error.Incomplete) - field = Asn1Field(kind: Asn1Tag.OctetString, klass: klass, - index: ttag, offset: cast[int](ab.offset), - length: cast[int](length)) - shallowCopy(field.buffer, ab.buffer) - ab.offset += cast[int](length) - return ok(field) - elif cast[byte](tag) == Asn1Tag.Null.code(): + + of Asn1Tag.Null.code(): # NULL if length != 0: return err(Asn1Error.Incorrect) - field = Asn1Field(kind: Asn1Tag.Null, klass: klass, - index: ttag, offset: cast[int](ab.offset), - length: 0) + + field = Asn1Field(kind: Asn1Tag.Null, klass: aclass, index: ttag, + offset: int(ab.offset), length: 0) shallowCopy(field.buffer, ab.buffer) - ab.offset += cast[int](length) + ab.offset += int(length) return ok(field) - elif cast[byte](tag) == Asn1Tag.Oid.code(): + + of Asn1Tag.Oid.code(): # OID - if not ab.isEnough(cast[int](length)): + if not ab.isEnough(int(length)): return err(Asn1Error.Incomplete) - field = Asn1Field(kind: Asn1Tag.Oid, klass: klass, - index: ttag, offset: cast[int](ab.offset), - length: cast[int](length)) + + field = Asn1Field(kind: Asn1Tag.Oid, klass: aclass, + index: ttag, offset: int(ab.offset), + length: int(length)) shallowCopy(field.buffer, ab.buffer) - ab.offset += cast[int](length) + ab.offset += int(length) return ok(field) - elif cast[byte](tag) == Asn1Tag.Sequence.code(): + + of Asn1Tag.Sequence.code(): # SEQUENCE - if not ab.isEnough(cast[int](length)): + if not ab.isEnough(int(length)): return err(Asn1Error.Incomplete) - field = Asn1Field(kind: Asn1Tag.Sequence, klass: klass, - index: ttag, offset: cast[int](ab.offset), - length: cast[int](length)) + + field = Asn1Field(kind: Asn1Tag.Sequence, klass: aclass, + index: ttag, offset: int(ab.offset), + length: int(length)) shallowCopy(field.buffer, ab.buffer) - ab.offset += cast[int](length) + ab.offset += int(length) return 
ok(field) + else: return err(Asn1Error.NoSupport) + inclass = false ttag = 0 else: return err(Asn1Error.NoSupport) -proc getBuffer*(field: Asn1Field): Asn1Buffer = +proc getBuffer*(field: Asn1Field): Asn1Buffer {.inline.} = ## Return ``field`` as Asn1Buffer to enter composite types. - shallowCopy(result.buffer, field.buffer) - result.offset = field.offset - result.length = field.length + Asn1Buffer(buffer: field.buffer, offset: field.offset, length: field.length) proc `==`*(field: Asn1Field, data: openarray[byte]): bool = ## Compares field ``field`` data with ``data`` and returns ``true`` if both ## buffers are equal. let length = len(field.buffer) - if length > 0: - if field.length == len(data): - result = equalMem(unsafeAddr field.buffer[field.offset], - unsafeAddr data[0], field.length) + if length == 0 and len(data) == 0: + true + else: + if length > 0: + if field.length == len(data): + CT.isEqual( + field.buffer.toOpenArray(field.offset, + field.offset + field.length - 1), + data.toOpenArray(0, field.length - 1)) + else: + false + else: + false proc init*(t: typedesc[Asn1Buffer], data: openarray[byte]): Asn1Buffer = ## Initialize ``Asn1Buffer`` from array of bytes ``data``. - result.buffer = @data + Asn1Buffer(buffer: @data) proc init*(t: typedesc[Asn1Buffer], data: string): Asn1Buffer = ## Initialize ``Asn1Buffer`` from hexadecimal string ``data``. - result.buffer = fromHex(data) + Asn1Buffer(buffer: ncrutils.fromHex(data)) proc init*(t: typedesc[Asn1Buffer]): Asn1Buffer = ## Initialize empty ``Asn1Buffer``. - result.buffer = newSeq[byte]() + Asn1Buffer(buffer: newSeq[byte]()) proc init*(t: typedesc[Asn1Composite], tag: Asn1Tag): Asn1Composite = ## Initialize ``Asn1Composite`` with tag ``tag``. - result.tag = tag - result.buffer = newSeq[byte]() + Asn1Composite(tag: tag, buffer: newSeq[byte]()) proc init*(t: typedesc[Asn1Composite], idx: int): Asn1Composite = ## Initialize ``Asn1Composite`` with tag context-specific id ``id``. - result.tag = Asn1Tag.Context - result.idx = idx - result.buffer = newSeq[byte]() + Asn1Composite(tag: Asn1Tag.Context, idx: idx, buffer: newSeq[byte]()) proc `$`*(buffer: Asn1Buffer): string = ## Return string representation of ``buffer``. - result = toHex(buffer.toOpenArray()) + ncrutils.toHex(buffer.toOpenArray()) proc `$`*(field: Asn1Field): string = ## Return string representation of ``field``. 
- result = "[" - result.add($field.kind) - result.add("]") - if field.kind == Asn1Tag.NoSupport: - result.add(" ") - result.add(toHex(field.toOpenArray())) - elif field.kind == Asn1Tag.Boolean: - result.add(" ") - result.add($field.vbool) - elif field.kind == Asn1Tag.Integer: - result.add(" ") + var res = "[" + res.add($field.kind) + res.add("]") + case field.kind + of Asn1Tag.Boolean: + res.add(" ") + res.add($field.vbool) + res + of Asn1Tag.Integer: + res.add(" ") if field.length <= 8: - result.add($field.vint) + res.add($field.vint) else: - result.add(toHex(field.toOpenArray())) - elif field.kind == Asn1Tag.BitString: - result.add(" ") - result.add("(") - result.add($field.ubits) - result.add(" bits) ") - result.add(toHex(field.toOpenArray())) - elif field.kind == Asn1Tag.OctetString: - result.add(" ") - result.add(toHex(field.toOpenArray())) - elif field.kind == Asn1Tag.Null: - result.add(" NULL") - elif field.kind == Asn1Tag.Oid: - result.add(" ") - result.add(toHex(field.toOpenArray())) - elif field.kind == Asn1Tag.Sequence: - result.add(" ") - result.add(toHex(field.toOpenArray())) + res.add(ncrutils.toHex(field.toOpenArray())) + res + of Asn1Tag.BitString: + res.add(" ") + res.add("(") + res.add($field.ubits) + res.add(" bits) ") + res.add(ncrutils.toHex(field.toOpenArray())) + res + of Asn1Tag.OctetString: + res.add(" ") + res.add(ncrutils.toHex(field.toOpenArray())) + res + of Asn1Tag.Null: + res.add(" NULL") + res + of Asn1Tag.Oid: + res.add(" ") + res.add(ncrutils.toHex(field.toOpenArray())) + res + of Asn1Tag.Sequence: + res.add(" ") + res.add(ncrutils.toHex(field.toOpenArray())) + res + of Asn1Tag.Context: + res.add(" ") + res.add(ncrutils.toHex(field.toOpenArray())) + res + else: + res.add(" ") + res.add(ncrutils.toHex(field.toOpenArray())) + res proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag) = ## Write empty value to buffer or composite with ``tag``. @@ -655,7 +788,7 @@ proc write*[T: Asn1Buffer|Asn1Composite](abc: var T, tag: Asn1Tag) = ## This procedure must be used to write `NULL`, `0` or empty `BIT STRING`, ## `OCTET STRING` types. doAssert(tag in {Asn1Tag.Null, Asn1Tag.Integer, Asn1Tag.BitString, - Asn1Tag.OctetString}) + Asn1Tag.OctetString}) var length: int if tag == Asn1Tag.Null: length = asn1EncodeNull(abc.toOpenArray()) diff --git a/libp2p/crypto/rsa.nim b/libp2p/crypto/rsa.nim index 835c9370a..6eee3c520 100644 --- a/libp2p/crypto/rsa.nim +++ b/libp2p/crypto/rsa.nim @@ -14,13 +14,13 @@ ## Copyright(C) 2018 Thomas Pornin . {.push raises: Defect.} - -import nimcrypto/utils import bearssl import minasn1 -export Asn1Error import stew/[results, ctops] -export results +# We use `ncrutils` for constant-time hexadecimal encoding/decoding procedures. +import nimcrypto/utils as ncrutils + +export Asn1Error, results const DefaultPublicExponent* = 65537'u32 @@ -574,14 +574,16 @@ proc init*(sig: var RsaSignature, data: openarray[byte]): Result[void, Asn1Error else: err(Asn1Error.Incorrect) -proc init*[T: RsaPKI](sospk: var T, data: string): Result[void, Asn1Error] {.inline.} = +proc init*[T: RsaPKI](sospk: var T, + data: string): Result[void, Asn1Error] {.inline.} = ## Initialize EC `private key`, `public key` or `scalar` ``sospk`` from ## hexadecimal string representation ``data``. ## ## Procedure returns ``Result[void, Asn1Status]``. 
- sospk.init(fromHex(data)) + sospk.init(ncrutils.fromHex(data)) -proc init*(t: typedesc[RsaPrivateKey], data: openarray[byte]): RsaResult[RsaPrivateKey] = +proc init*(t: typedesc[RsaPrivateKey], + data: openarray[byte]): RsaResult[RsaPrivateKey] = ## Initialize RSA private key from ASN.1 DER binary representation ``data`` ## and return constructed object. var res: RsaPrivateKey @@ -590,7 +592,8 @@ proc init*(t: typedesc[RsaPrivateKey], data: openarray[byte]): RsaResult[RsaPriv else: ok(res) -proc init*(t: typedesc[RsaPublicKey], data: openarray[byte]): RsaResult[RsaPublicKey] = +proc init*(t: typedesc[RsaPublicKey], + data: openarray[byte]): RsaResult[RsaPublicKey] = ## Initialize RSA public key from ASN.1 DER binary representation ``data`` ## and return constructed object. var res: RsaPublicKey @@ -599,7 +602,8 @@ proc init*(t: typedesc[RsaPublicKey], data: openarray[byte]): RsaResult[RsaPubli else: ok(res) -proc init*(t: typedesc[RsaSignature], data: openarray[byte]): RsaResult[RsaSignature] = +proc init*(t: typedesc[RsaSignature], + data: openarray[byte]): RsaResult[RsaSignature] = ## Initialize RSA signature from raw binary representation ``data`` and ## return constructed object. var res: RsaSignature @@ -611,7 +615,7 @@ proc init*(t: typedesc[RsaSignature], data: openarray[byte]): RsaResult[RsaSigna proc init*[T: RsaPKI](t: typedesc[T], data: string): T {.inline.} = ## Initialize RSA `private key`, `public key` or `signature` from hexadecimal ## string representation ``data`` and return constructed object. - result = t.init(fromHex(data)) + result = t.init(ncrutils.fromHex(data)) proc `$`*(key: RsaPrivateKey): string = ## Return string representation of RSA private key. @@ -622,21 +626,24 @@ proc `$`*(key: RsaPrivateKey): string = result.add($key.seck.nBitlen) result.add(" bits)\n") result.add("p = ") - result.add(toHex(getArray(key.buffer, key.seck.p, key.seck.plen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.seck.p, key.seck.plen))) result.add("\nq = ") - result.add(toHex(getArray(key.buffer, key.seck.q, key.seck.qlen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.seck.q, key.seck.qlen))) result.add("\ndp = ") - result.add(toHex(getArray(key.buffer, key.seck.dp, key.seck.dplen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dp, + key.seck.dplen))) result.add("\ndq = ") - result.add(toHex(getArray(key.buffer, key.seck.dq, key.seck.dqlen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.seck.dq, + key.seck.dqlen))) result.add("\niq = ") - result.add(toHex(getArray(key.buffer, key.seck.iq, key.seck.iqlen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.seck.iq, + key.seck.iqlen))) result.add("\npre = ") - result.add(toHex(getArray(key.buffer, key.pexp, key.pexplen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.pexp, key.pexplen))) result.add("\nm = ") - result.add(toHex(getArray(key.buffer, key.pubk.n, key.pubk.nlen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.pubk.n, key.pubk.nlen))) result.add("\npue = ") - result.add(toHex(getArray(key.buffer, key.pubk.e, key.pubk.elen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.pubk.e, key.pubk.elen))) result.add("\n") proc `$`*(key: RsaPublicKey): string = @@ -648,9 +655,9 @@ proc `$`*(key: RsaPublicKey): string = result = "RSA key (" result.add($nbitlen) result.add(" bits)\nn = ") - result.add(toHex(getArray(key.buffer, key.key.n, key.key.nlen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.key.n, key.key.nlen))) result.add("\ne = ") - 
result.add(toHex(getArray(key.buffer, key.key.e, key.key.elen))) + result.add(ncrutils.toHex(getArray(key.buffer, key.key.e, key.key.elen))) result.add("\n") proc `$`*(sig: RsaSignature): string = @@ -659,7 +666,7 @@ proc `$`*(sig: RsaSignature): string = result = "Empty or uninitialized RSA signature" else: result = "RSA signature (" - result.add(toHex(sig.buffer)) + result.add(ncrutils.toHex(sig.buffer)) result.add(")") proc `==`*(a, b: RsaPrivateKey): bool = diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 51d261327..27e653e44 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -138,12 +138,9 @@ proc closeRemote*(s: LPChannel) {.async.} = trace "got EOF, closing channel" try: await s.drainBuffer() - s.isEof = true # set EOF immediately to prevent further reads - await s.close() # close local end - - # call to avoid leaks - await procCall BufferStream(s).close() # close parent bufferstream + # close parent bufferstream to prevent further reads + await procCall BufferStream(s).close() trace "channel closed on EOF" except CancelledError as exc: diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 317e056a5..e6b519d04 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -96,7 +96,7 @@ proc newStreamInternal*(m: Mplex, proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} = ## remove the local channel from the internal tables ## - await chann.closeEvent.wait() + await chann.join() if not isNil(chann): m.getChannelList(chann.initiator).del(chann.id) trace "cleaned up channel", id = chann.id diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index 9b6feb772..3ec69adde 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -31,14 +31,9 @@ type method subscribeTopic*(f: FloodSub, topic: string, subscribe: bool, - peerId: string) {.gcsafe, async.} = + peerId: PeerID) {.gcsafe, async.} = await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId) - let peer = f.peers.getOrDefault(peerId) - if peer == nil: - debug "subscribeTopic on a nil peer!", peer = peerId - return - if topic notin f.floodsub: f.floodsub[topic] = initHashSet[PubSubPeer]() @@ -51,16 +46,20 @@ method subscribeTopic*(f: FloodSub, # unsubscribe the peer from the topic f.floodsub[topic].excl(peer) -method handleDisconnect*(f: FloodSub, peer: PubSubPeer) = +method unsubscribePeer*(f: FloodSub, peer: PeerID) = ## handle peer disconnects ## - - procCall PubSub(f).handleDisconnect(peer) - if not(isNil(peer)) and peer.peerInfo notin f.conns: - for t in toSeq(f.floodsub.keys): - if t in f.floodsub: - f.floodsub[t].excl(peer) + trace "unsubscribing floodsub peer", peer = $peer + let pubSubPeer = f.peers.getOrDefault(peer) + if pubSubPeer.isNil: + return + + for t in toSeq(f.floodsub.keys): + if t in f.floodsub: + f.floodsub[t].excl(pubSubPeer) + + procCall PubSub(f).unsubscribePeer(peer) method rpcHandler*(f: FloodSub, peer: PubSubPeer, @@ -77,7 +76,7 @@ method rpcHandler*(f: FloodSub, if msgId notin f.seen: f.seen.put(msgId) # add the message to the seen cache - if f.verifySignature and not msg.verify(peer.peerInfo): + if f.verifySignature and not msg.verify(peer.peerId): trace "dropping message due to failed signature verification" continue @@ -102,7 +101,10 @@ method rpcHandler*(f: FloodSub, trace "exception in message handler", exc = exc.msg # forward the message to all peers interested in it - let published = 
await f.publishHelper(toSendPeers, m.messages, DefaultSendTimeout) + let published = await f.broadcast( + toSeq(toSendPeers), + RPCMsg(messages: m.messages), + DefaultSendTimeout) trace "forwared message to peers", peers = published @@ -118,11 +120,6 @@ method init*(f: FloodSub) = f.handler = handler f.codec = FloodSubCodec -method subscribePeer*(p: FloodSub, - conn: Connection) = - procCall PubSub(p).subscribePeer(conn) - asyncCheck p.handleConn(conn, FloodSubCodec) - method publish*(f: FloodSub, topic: string, data: seq[byte], @@ -143,7 +140,10 @@ method publish*(f: FloodSub, let msg = Message.init(f.peerInfo, data, topic, f.msgSeqno, f.sign) # start the future but do not wait yet - let published = await f.publishHelper(f.floodsub.getOrDefault(topic), @[msg], timeout) + let published = await f.broadcast( + toSeq(f.floodsub.getOrDefault(topic)), + RPCMsg(messages: @[msg]), + timeout) when defined(libp2p_expensive_metrics): libp2p_pubsub_messages_published.inc(labelValues = [topic]) @@ -167,8 +167,6 @@ method unsubscribeAll*(f: FloodSub, topic: string) {.async.} = method initPubSub*(f: FloodSub) = procCall PubSub(f).initPubSub() - f.peers = initTable[string, PubSubPeer]() - f.topics = initTable[string, Topic]() f.floodsub = initTable[string, HashSet[PubSubPeer]]() f.seen = newTimedCache[string](2.minutes) f.init() diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 5236a6051..b4b636657 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -404,10 +404,10 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = .set(g.mesh.peers(topic).int64, labelValues = [topic]) # Send changes to peers after table updates to avoid stale state - for p in grafts: - await p.sendGraft(@[topic]) - for p in prunes: - await p.sendPrune(@[topic]) + let graft = RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)]))) + let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)]))) + discard await g.broadcast(grafts, graft, DefaultSendTimeout) + discard await g.broadcast(prunes, prune, DefaultSendTimeout) trace "mesh balanced, got peers", peers = g.mesh.peers(topic) @@ -426,7 +426,7 @@ proc dropFanoutPeers(g: GossipSub) = libp2p_gossipsub_peers_per_topic_fanout .set(g.fanout.peers(topic).int64, labelValues = [topic]) -proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = +proc getGossipPeers(g: GossipSub): Table[PubSubPeer, ControlMessage] {.gcsafe.} = ## gossip iHave messages to peers ## @@ -458,10 +458,10 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = if peer in gossipPeers: continue - if peer.id notin result: - result[peer.id] = controlMsg + if peer notin result: + result[peer] = controlMsg - result[peer.id].ihave.add(ihave) + result[peer].ihave.add(ihave) func `/`(a, b: Duration): float64 = let @@ -582,8 +582,11 @@ proc heartbeat(g: GossipSub) {.async.} = let peers = g.getGossipPeers() var sent: seq[Future[void]] for peer, control in peers: - g.peers.withValue(peer, pubsubPeer) do: - sent &= pubsubPeer[].send(RPCMsg(control: some(control))) + g.peers.withValue(peer.peerId, pubsubPeer) do: + sent &= g.send( + pubsubPeer[], + RPCMsg(control: some(control)), + DefaultSendTimeout) checkFutures(await allFinished(sent)) g.mcache.shift() # shift the cache @@ -599,35 +602,37 @@ proc heartbeat(g: GossipSub) {.async.} = await sleepAsync(GossipSubHeartbeatInterval) -method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = 
+method unsubscribePeer*(g: GossipSub, peer: PeerID) = ## handle peer disconnects - ## - - procCall FloodSub(g).handleDisconnect(peer) + ## + + trace "unsubscribing gossipsub peer", peer = $peer + let pubSubPeer = g.peers.getOrDefault(peer) + if pubSubPeer.isNil: + return + + for t in toSeq(g.gossipsub.keys): + g.gossipsub.removePeer(t, pubSubPeer) - if not(isNil(peer)) and peer.peerInfo notin g.conns: - for t in toSeq(g.gossipsub.keys): - g.gossipsub.removePeer(t, peer) - when defined(libp2p_expensive_metrics): - libp2p_gossipsub_peers_per_topic_gossipsub - .set(g.gossipsub.peers(t).int64, labelValues = [t]) + libp2p_gossipsub_peers_per_topic_gossipsub + .set(g.gossipsub.peers(t).int64, labelValues = [t]) - for t in toSeq(g.mesh.keys): - if peer in g.mesh[t]: + for t in toSeq(g.mesh.keys): + if peer in g.mesh[t]: g.pruned(peer, t) - g.mesh.removePeer(t, peer) + g.mesh.removePeer(t, pubSubPeer) - when defined(libp2p_expensive_metrics): - libp2p_gossipsub_peers_per_topic_mesh - .set(g.mesh.peers(t).int64, labelValues = [t]) + when defined(libp2p_expensive_metrics): + libp2p_gossipsub_peers_per_topic_mesh + .set(g.mesh.peers(t).int64, labelValues = [t]) - for t in toSeq(g.fanout.keys): - g.fanout.removePeer(t, peer) + for t in toSeq(g.fanout.keys): + g.fanout.removePeer(t, pubSubPeer) - when defined(libp2p_expensive_metrics): - libp2p_gossipsub_peers_per_topic_fanout - .set(g.fanout.peers(t).int64, labelValues = [t]) + when defined(libp2p_expensive_metrics): + libp2p_gossipsub_peers_per_topic_fanout + .set(g.fanout.peers(t).int64, labelValues = [t]) # TODO # if peer.peerInfo.maintain: @@ -644,19 +649,16 @@ method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = for topic, info in g.peerStats[peer].topicInfos.mpairs: info.firstMessageDeliveries = 0 -method subscribePeer*(p: GossipSub, - conn: Connection) = - procCall PubSub(p).subscribePeer(conn) - asyncCheck p.handleConn(conn, GossipSubCodec) + procCall FloodSub(g).unsubscribePeer(peer) method subscribeTopic*(g: GossipSub, topic: string, subscribe: bool, - peerId: string) {.gcsafe, async.} = + peerId: PeerID) {.gcsafe, async.} = await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId) logScope: - peer = peerId + peer = $peerId topic let peer = g.peers.getOrDefault(peerId) @@ -817,8 +819,8 @@ method rpcHandler*(g: GossipSub, g.seen.put(msgId) # add the message to the seen cache - if g.verifySignature and not msg.verify(peer.peerInfo): - trace "dropping message due to failed signature verification", peer + if g.verifySignature and not msg.verify(peer.peerId): + trace "dropping message due to failed signature verification" g.punishPeer(peer, msg) continue @@ -872,7 +874,10 @@ method rpcHandler*(g: GossipSub, trace "exception in message handler", exc = exc.msg # forward the message to all peers interested in it - let published = await g.publishHelper(toSendPeers, m.messages, DefaultSendTimeout) + let published = await g.broadcast( + toSeq(toSendPeers), + RPCMsg(messages: m.messages), + DefaultSendTimeout) trace "forwared message to peers", peers = published @@ -889,8 +894,10 @@ method rpcHandler*(g: GossipSub, respControl.ihave.len > 0: try: info "sending control message", msg = respControl - await peer.send( - RPCMsg(control: some(respControl), messages: messages)) + await g.send( + peer, + RPCMsg(control: some(respControl), messages: messages), + DefaultSendTimeout) except CancelledError as exc: raise exc except CatchableError as exc: @@ -917,12 +924,10 @@ method unsubscribe*(g: GossipSub, if topic in g.mesh: let peers = 
g.mesh.getOrDefault(topic) g.mesh.del(topic) - - var pending = newSeq[Future[void]]() for peer in peers: g.pruned(peer, topic) - pending.add(peer.sendPrune(@[topic])) - checkFutures(await allFinished(pending)) + let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)]))) + discard await g.broadcast(toSeq(peers), prune, DefaultSendTimeout) method unsubscribeAll*(g: GossipSub, topic: string) {.async.} = await procCall PubSub(g).unsubscribeAll(topic) @@ -930,12 +935,10 @@ method unsubscribeAll*(g: GossipSub, topic: string) {.async.} = if topic in g.mesh: let peers = g.mesh.getOrDefault(topic) g.mesh.del(topic) - - var pending = newSeq[Future[void]]() for peer in peers: g.pruned(peer, topic) - pending.add(peer.sendPrune(@[topic])) - checkFutures(await allFinished(pending)) + let prune = RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)]))) + discard await g.broadcast(toSeq(peers), prune, DefaultSendTimeout) method publish*(g: GossipSub, topic: string, @@ -986,7 +989,7 @@ method publish*(g: GossipSub, if msgId notin g.mcache: g.mcache.put(msgId, msg) - let published = await g.publishHelper(peers, @[msg], timeout) + let published = await g.broadcast(toSeq(peers), RPCMsg(messages: @[msg]), timeout) when defined(libp2p_expensive_metrics): if published > 0: libp2p_pubsub_messages_published.inc(labelValues = [topic]) diff --git a/libp2p/protocols/pubsub/peertable.nim b/libp2p/protocols/pubsub/peertable.nim index eda623168..d294c0155 100644 --- a/libp2p/protocols/pubsub/peertable.nim +++ b/libp2p/protocols/pubsub/peertable.nim @@ -13,10 +13,10 @@ import pubsubpeer, ../../peerid type PeerTable* = Table[string, HashSet[PubSubPeer]] # topic string to peer map -proc hasPeerID*(t: PeerTable, topic, peerId: string): bool = +proc hasPeerID*(t: PeerTable, topic: string, peerId: PeerID): bool = let peers = toSeq(t.getOrDefault(topic)) peers.any do (peer: PubSubPeer) -> bool: - peer.id == peerId + peer.peerId == peerId func addPeer*(table: var PeerTable, topic: string, peer: PubSubPeer): bool = # returns true if the peer was added, diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 94f7e8b0d..5f3c176d8 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -11,6 +11,7 @@ import std/[tables, sequtils, sets] import chronos, chronicles, metrics import pubsubpeer, rpc/[message, messages], + ../../switch, ../protocol, ../../stream/connection, ../../peerid, @@ -53,64 +54,77 @@ type handler*: seq[TopicHandler] PubSub* = ref object of LPProtocol + switch*: Switch # the switch used to dial/connect to peers peerInfo*: PeerInfo # this peer's info topics*: Table[string, Topic] # local topics - peers*: Table[string, PubSubPeer] # peerid to peer map - conns*: Table[PeerInfo, HashSet[Connection]] # peers connections + peers*: Table[PeerID, PubSubPeer] # peerid to peer map triggerSelf*: bool # trigger own local handler on publish verifySignature*: bool # enable signature verification sign*: bool # enable message signing cleanupLock: AsyncLock validators*: Table[string, HashSet[ValidatorHandler]] - observers: ref seq[PubSubObserver] # ref as in smart_ptr - msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) + observers: ref seq[PubSubObserver] # ref as in smart_ptr + msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) msgSeqno*: uint64 + lifetimeFut*: Future[void] # pubsub liftime future -method handleConnect*(p: PubSub, peer: PubSubPeer) {.base.} = - discard - 
-method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = +method unsubscribePeer*(p: PubSub, peerId: PeerID) {.base.} = ## handle peer disconnects ## - - if not(isNil(peer)) and peer.peerInfo notin p.conns: - trace "deleting peer", peer = peer.id - peer.onConnect.fire() # Make sure all pending sends are unblocked - p.peers.del(peer.id) - trace "peer disconnected", peer = peer.id - # metrics - libp2p_pubsub_peers.set(p.peers.len.int64) + trace "unsubscribing pubsub peer", peer = $peerId + if peerId in p.peers: + p.peers.del(peerId) -proc onConnClose(p: PubSub, conn: Connection) {.async.} = + libp2p_pubsub_peers.set(p.peers.len.int64) + +proc send*( + p: PubSub, + peer: PubSubPeer, + msg: RPCMsg, + timeout: Duration) {.async.} = + ## send to remote peer + ## + + trace "sending pubsub message to peer", peer = $peer, msg = msg try: - let peer = conn.peerInfo - await conn.closeEvent.wait() - - if peer in p.conns: - p.conns[peer].excl(conn) - if p.conns[peer].len <= 0: - p.conns.del(peer) - - if peer.id in p.peers: - p.handleDisconnect(p.peers[peer.id]) - + await peer.send(msg, timeout) except CancelledError as exc: raise exc except CatchableError as exc: - trace "exception in onConnClose handler", exc = exc.msg + trace "exception sending pubsub message to peer", peer = $peer, msg = msg + p.unsubscribePeer(peer.peerId) + raise exc + +proc broadcast*( + p: PubSub, + sendPeers: seq[PubSubPeer], + msg: RPCMsg, + timeout: Duration): Future[int] {.async.} = + ## send messages and cleanup failed peers + ## + + trace "broadcasting messages to peers", peers = sendPeers.len, message = msg + let sent = await allFinished( + sendPeers.mapIt( p.send(it, msg, timeout) )) + return sent.filterIt( it.finished and it.error.isNil ).len + trace "messages broadcasted to peers", peers = sent.len proc sendSubs*(p: PubSub, peer: PubSubPeer, topics: seq[string], subscribe: bool) {.async.} = ## send subscriptions to remote peer - asyncCheck peer.sendSubOpts(topics, subscribe) + await p.send( + peer, + RPCMsg( + subscriptions: topics.mapIt(SubOpts(subscribe: subscribe, topic: it))), + DefaultSendTimeout) method subscribeTopic*(p: PubSub, topic: string, subscribe: bool, - peerId: string) {.base, async.} = + peerId: PeerID) {.base, async.} = # called when remote peer subscribes to a topic var peer = p.peers.getOrDefault(peerId) if not isNil(peer): @@ -130,27 +144,27 @@ method rpcHandler*(p: PubSub, if m.subscriptions.len > 0: # if there are any subscriptions for s in m.subscriptions: # subscribe/unsubscribe the peer for each topic trace "about to subscribe to topic", topicId = s.topic - await p.subscribeTopic(s.topic, s.subscribe, peer.id) + await p.subscribeTopic(s.topic, s.subscribe, peer.peerId) -proc getOrCreatePeer(p: PubSub, - peerInfo: PeerInfo, - proto: string): PubSubPeer = - if peerInfo.id in p.peers: - return p.peers[peerInfo.id] +proc getOrCreatePeer*( + p: PubSub, + peer: PeerID, + proto: string): PubSubPeer = + if peer in p.peers: + return p.peers[peer] # create new pubsub peer - let peer = newPubSubPeer(peerInfo, proto) - trace "created new pubsub peer", peerId = peer.id + let pubSubPeer = newPubSubPeer(peer, p.switch, proto) + trace "created new pubsub peer", peerId = $peer - p.peers[peer.id] = peer - peer.observers = p.observers + p.peers[peer] = pubSubPeer + pubSubPeer.observers = p.observers handleConnect(p, peer) # metrics libp2p_pubsub_peers.set(p.peers.len.int64) - - return peer + return pubSubPeer method handleConn*(p: PubSub, conn: Connection, @@ -171,19 +185,11 @@ method handleConn*(p: 
PubSub, await conn.close() return - # track connection - p.conns.mgetOrPut(conn.peerInfo, - initHashSet[Connection]()) - .incl(conn) - - asyncCheck p.onConnClose(conn) - proc handler(peer: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = # call pubsub rpc handler await p.rpcHandler(peer, msgs) - let peer = p.getOrCreatePeer(conn.peerInfo, proto) - + let peer = p.getOrCreatePeer(conn.peerInfo.peerId, proto) if p.topics.len > 0: await p.sendSubs(peer, toSeq(p.topics.keys), true) @@ -198,32 +204,16 @@ method handleConn*(p: PubSub, finally: await conn.close() -method subscribePeer*(p: PubSub, conn: Connection) {.base.} = - if not(isNil(conn)): - trace "subscribing to peer", peerId = conn.peerInfo.id +method subscribePeer*(p: PubSub, peer: PeerID) {.base.} = + ## subscribe to remote peer to receive/send pubsub + ## messages + ## - # track connection - p.conns.mgetOrPut(conn.peerInfo, - initHashSet[Connection]()) - .incl(conn) + let pubsubPeer = p.getOrCreatePeer(peer, p.codec) + if p.topics.len > 0: + asyncCheck p.sendSubs(pubsubPeer, toSeq(p.topics.keys), true) - asyncCheck p.onConnClose(conn) - - let peer = p.getOrCreatePeer(conn.peerInfo, p.codec) - if not peer.connected: - peer.conn = conn - -method unsubscribePeer*(p: PubSub, peerInfo: PeerInfo) {.base, async.} = - if peerInfo.id in p.peers: - let peer = p.peers[peerInfo.id] - - trace "unsubscribing from peer", peerId = $peerInfo - if not(isNil(peer)) and not(isNil(peer.conn)): - await peer.conn.close() - -proc connected*(p: PubSub, peerId: PeerID): bool = - p.peers.withValue($peerId, peer): - return peer[] != nil and peer[].connected + pubsubPeer.subscribed = true method unsubscribe*(p: PubSub, topics: seq[TopicPair]) {.base, async.} = @@ -278,40 +268,6 @@ method subscribe*(p: PubSub, # metrics libp2p_pubsub_topics.set(p.topics.len.int64) -proc publishHelper*(p: PubSub, - sendPeers: HashSet[PubSubPeer], - msgs: seq[Message], - timeout: Duration): Future[int] {.async.} = - # send messages and cleanup failed peers - var sent: seq[tuple[id: string, fut: Future[void]]] - for sendPeer in sendPeers: - # avoid sending to self - if sendPeer.peerInfo == p.peerInfo: - continue - - trace "sending messages to peer", peer = sendPeer.id, msgs - sent.add((id: sendPeer.id, fut: sendPeer.send(RPCMsg(messages: msgs), timeout))) - - var published: seq[string] - var failed: seq[string] - let futs = await allFinished(sent.mapIt(it.fut)) - for s in futs: - let f = sent.filterIt(it.fut == s) - if f.len > 0: - if s.failed: - trace "sending messages to peer failed", peer = f[0].id - failed.add(f[0].id) - else: - trace "sending messages to peer succeeded", peer = f[0].id - published.add(f[0].id) - - for f in failed: - let peer = p.peers.getOrDefault(f) - if not(isNil(peer)) and not(isNil(peer.conn)): - await peer.conn.close() - - return published.len - method publish*(p: PubSub, topic: string, data: seq[byte], @@ -381,28 +337,35 @@ method validate*(p: PubSub, message: Message): Future[bool] {.async, base.} = else: libp2p_pubsub_validation_failure.inc() -proc newPubSub*[PubParams: object | bool](P: typedesc[PubSub], - peerInfo: PeerInfo, - triggerSelf: bool = false, - verifySignature: bool = true, - sign: bool = true, - msgIdProvider: MsgIdProvider = defaultMsgIdProvider, - params: PubParams = false): P = +proc init*[PubParams: object | bool]( + P: typedesc[PubSub], + switch: Switch, + triggerSelf: bool = false, + verifySignature: bool = true, + sign: bool = true, + msgIdProvider: MsgIdProvider = defaultMsgIdProvider, + parameters: PubParams = false): P = when PubParams 
is bool: - result = P(peerInfo: peerInfo, + result = P(switch: switch, + peerInfo: switch.peerInfo, triggerSelf: triggerSelf, verifySignature: verifySignature, sign: sign, + peers: initTable[PeerID, PubSubPeer](), + topics: initTable[string, Topic](), cleanupLock: newAsyncLock(), msgIdProvider: msgIdProvider) else: - result = P(peerInfo: peerInfo, - triggerSelf: triggerSelf, - verifySignature: verifySignature, - sign: sign, - cleanupLock: newAsyncLock(), - msgIdProvider: msgIdProvider, - parameters: params) + result = P(switch: switch, + peerInfo: switch.peerInfo, + triggerSelf: triggerSelf, + verifySignature: verifySignature, + sign: sign, + peers: initTable[PeerID, PubSubPeer](), + topics: initTable[string, Topic](), + cleanupLock: newAsyncLock(), + msgIdProvider: msgIdProvider, + parameters: parameters) result.initPubSub() @@ -412,6 +375,3 @@ proc removeObserver*(p: PubSub; observer: PubSubObserver) = let idx = p.observers[].find(observer) if idx != -1: p.observers[].del(idx) - -proc connected*(p: PubSub, peerInfo: PeerInfo): bool {.deprecated: "Use PeerID version".} = - peerInfo != nil and connected(p, peerInfo.peerId) diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 771cb5643..bc1146f33 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -11,6 +11,7 @@ import std/[hashes, options, sequtils, strutils, tables, hashes, sets] import chronos, chronicles, nimcrypto/sha2, metrics import rpc/[messages, message, protobuf], timedcache, + ../../switch, ../../peerid, ../../peerinfo, ../../stream/connection, @@ -28,7 +29,6 @@ when defined(libp2p_expensive_metrics): declareCounter(libp2p_pubsub_skipped_sent_messages, "number of sent skipped messages", labels = ["id"]) const - DefaultReadTimeout* = 1.minutes DefaultSendTimeout* = 10.seconds type @@ -37,15 +37,17 @@ type onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe, raises: [Defect].} PubSubPeer* = ref object of RootObj - proto*: string # the protocol that this peer joined from + switch*: Switch # switch instance to dial peers + codec*: string # the protocol that this peer joined from sendConn: Connection - peerInfo*: PeerInfo + peerId*: PeerID handler*: RPCHandler topics*: HashSet[string] sentRpcCache: TimedCache[string] # cache for already sent messages recvdRpcCache: TimedCache[string] # cache for already received messages - onConnect*: AsyncEvent observers*: ref seq[PubSubObserver] # ref as in smart_ptr + subscribed*: bool # are we subscribed to this peer + sendLock*: AsyncLock # send connection lock score*: float64 @@ -57,19 +59,13 @@ func hash*(p: PubSubPeer): Hash = # int is either 32/64, so intptr basically, pubsubpeer is a ref cast[pointer](p).hash -proc id*(p: PubSubPeer): string = p.peerInfo.id +proc id*(p: PubSubPeer): string = + doAssert(not p.isNil, "nil pubsubpeer") + p.peerId.pretty proc connected*(p: PubSubPeer): bool = - not(isNil(p.sendConn)) - -proc `conn=`*(p: PubSubPeer, conn: Connection) = - if not(isNil(conn)): - trace "attaching send connection for peer", peer = p.id - p.sendConn = conn - p.onConnect.fire() - -proc conn*(p: PubSubPeer): Connection = - p.sendConn + not p.sendConn.isNil and not + (p.sendConn.closed or p.sendConn.atEof) proc recvObservers(p: PubSubPeer, msg: var RPCMsg) = # trigger hooks @@ -88,12 +84,13 @@ proc sendObservers(p: PubSubPeer, msg: var RPCMsg) = proc handle*(p: PubSubPeer, conn: Connection) {.async.} = logScope: peer = p.id + debug "starting pubsub read loop for peer", closed = 
conn.closed try: try: while not conn.atEof: trace "waiting for data", closed = conn.closed - let data = await conn.readLp(64 * 1024).wait(DefaultReadTimeout) + let data = await conn.readLp(64 * 1024) let digest = $(sha256.digest(data)) trace "read data from peer", data = data.shortLog if digest in p.recvdRpcCache: @@ -129,12 +126,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} = raise exc except CatchableError as exc: trace "Exception occurred in PubSubPeer.handle", exc = exc.msg - raise exc proc send*( p: PubSubPeer, msg: RPCMsg, timeout: Duration = DefaultSendTimeout) {.async.} = + + doAssert(not isNil(p), "pubsubpeer nil!") + logScope: peer = p.id rpcMsg = shortLog(msg) @@ -160,91 +159,55 @@ proc send*( libp2p_pubsub_skipped_sent_messages.inc(labelValues = [p.id]) return - proc sendToRemote() {.async.} = - logScope: - peer = p.id - rpcMsg = shortLog(msg) - - trace "about to send message" - - if not p.onConnect.isSet: - await p.onConnect.wait() - - if p.connected: # this can happen if the remote disconnected - trace "sending encoded msgs to peer" - - await p.sendConn.writeLp(encoded) - p.sentRpcCache.put(digest) - trace "sent pubsub message to remote" - - when defined(libp2p_expensive_metrics): - for x in mm.messages: - for t in x.topicIDs: - # metrics - libp2p_pubsub_sent_messages.inc(labelValues = [p.id, t]) - - let sendFut = sendToRemote() try: - await sendFut.wait(timeout) + trace "about to send message" + if not p.connected: + try: + await p.sendLock.acquire() + trace "no send connection, dialing peer" + # get a send connection if there is none + p.sendConn = await p.switch.dial( + p.peerId, p.codec) + + if not p.connected: + raise newException(CatchableError, "unable to get send pubsub stream") + + # install a reader on the send connection + asyncCheck p.handle(p.sendConn) + finally: + if p.sendLock.locked: + p.sendLock.release() + + trace "sending encoded msgs to peer" + await p.sendConn.writeLp(encoded).wait(timeout) + p.sentRpcCache.put(digest) + trace "sent pubsub message to remote" + + when defined(libp2p_expensive_metrics): + for x in mm.messages: + for t in x.topicIDs: + # metrics + libp2p_pubsub_sent_messages.inc(labelValues = [p.id, t]) + except CatchableError as exc: trace "unable to send to remote", exc = exc.msg - if not sendFut.finished: - sendFut.cancel() - if not(isNil(p.sendConn)): await p.sendConn.close() p.sendConn = nil - p.onConnect.clear() raise exc -proc sendSubOpts*(p: PubSubPeer, topics: seq[string], subscribe: bool) {.async.} = - trace "sending subscriptions", peer = p.id, subscribe, topicIDs = topics - - try: - await p.send(RPCMsg( - subscriptions: topics.mapIt(SubOpts(subscribe: subscribe, topic: it))), - # the long timeout is mostly for cases where - # the connection is flaky at the beggingin - timeout = 3.minutes) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception sending subscriptions", exc = exc.msg - -proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} = - trace "sending graft to peer", peer = p.id, topicIDs = topics - - try: - await p.send(RPCMsg(control: some( - ControlMessage(graft: topics.mapIt(ControlGraft(topicID: it))))), - timeout = 1.minutes) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception sending grafts", exc = exc.msg - -proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} = - trace "sending prune to peer", peer = p.id, topicIDs = topics - - try: - await p.send(RPCMsg(control: some( - ControlMessage(prune: 
topics.mapIt(ControlPrune(topicID: it))))), - timeout = 1.minutes) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception sending prunes", exc = exc.msg - proc `$`*(p: PubSubPeer): string = p.id -proc newPubSubPeer*(peerInfo: PeerInfo, - proto: string): PubSubPeer = +proc newPubSubPeer*(peerId: PeerID, + switch: Switch, + codec: string): PubSubPeer = new result - result.proto = proto - result.peerInfo = peerInfo + result.switch = switch + result.codec = codec + result.peerId = peerId result.sentRpcCache = newTimedCache[string](2.minutes) result.recvdRpcCache = newTimedCache[string](2.minutes) - result.onConnect = newAsyncEvent() result.topics = initHashSet[string]() + result.sendLock = newAsyncLock() diff --git a/libp2p/protocols/pubsub/rpc/message.nim b/libp2p/protocols/pubsub/rpc/message.nim index 29d700500..e49b54d9e 100644 --- a/libp2p/protocols/pubsub/rpc/message.nim +++ b/libp2p/protocols/pubsub/rpc/message.nim @@ -10,7 +10,8 @@ {.push raises: [Defect].} import chronicles, metrics, stew/[byteutils, endians2] -import ./messages, ./protobuf, +import ./messages, + ./protobuf, ../../../peerid, ../../../peerinfo, ../../../crypto/crypto, @@ -32,7 +33,7 @@ func defaultMsgIdProvider*(m: Message): string = proc sign*(msg: Message, p: PeerInfo): CryptoResult[seq[byte]] = ok((? p.privateKey.sign(PubSubPrefix & encodeMessage(msg))).getBytes()) -proc verify*(m: Message, p: PeerInfo): bool = +proc verify*(m: Message, p: PeerID): bool = if m.signature.len > 0 and m.key.len > 0: var msg = m msg.signature = @[] @@ -51,17 +52,17 @@ proc verify*(m: Message, p: PeerInfo): bool = proc init*( T: type Message, - p: PeerInfo, + peer: PeerInfo, data: seq[byte], topic: string, seqno: uint64, sign: bool = true): Message {.gcsafe, raises: [CatchableError, Defect].} = result = Message( - fromPeer: p.peerId, + fromPeer: peer.peerId, data: data, seqno: @(seqno.toBytesBE), # unefficient, fine for now topicIDs: @[topic]) - if sign and p.publicKey.isSome: - result.signature = sign(result, p).tryGet() - result.key = p.publicKey.get().getBytes().tryGet() + if sign and peer.publicKey.isSome: + result.signature = sign(result, peer).tryGet() + result.key = peer.publicKey.get().getBytes().tryGet() diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 1466a50c4..f8cea1996 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -30,11 +30,13 @@ type proc init*[T: SecureConn](C: type T, conn: Connection, peerInfo: PeerInfo, - observedAddr: Multiaddress): T = + observedAddr: Multiaddress, + timeout: Duration = DefaultConnectionTimeout): T = result = C(stream: conn, peerInfo: peerInfo, observedAddr: observedAddr, - closeEvent: conn.closeEvent) + closeEvent: conn.closeEvent, + timeout: timeout) result.initStream() method initStream*(s: SecureConn) = @@ -62,7 +64,7 @@ proc handleConn*(s: Secure, initiator: bool): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) if not isNil(sconn): - conn.closeEvent.wait() + conn.join() .addCallback do(udata: pointer = nil): asyncCheck sconn.close() diff --git a/libp2p/standard_setup.nim b/libp2p/standard_setup.nim index 688b6e7b3..d9d51ff21 100644 --- a/libp2p/standard_setup.nim +++ b/libp2p/standard_setup.nim @@ -1,16 +1,9 @@ -# compile time options here -const - libp2p_pubsub_sign {.booldefine.} = true - libp2p_pubsub_verify {.booldefine.} = true - import options, tables, chronos, bearssl, switch, peerid, peerinfo, stream/connection, 
multiaddress, crypto/crypto, transports/[transport, tcptransport], muxers/[muxer, mplex/mplex, mplex/types], - protocols/[identify, secure/secure], - protocols/pubsub/[pubsub, floodsub, gossipsub], - protocols/pubsub/rpc/message + protocols/[identify, secure/secure] import protocols/secure/noise, @@ -26,17 +19,12 @@ type proc newStandardSwitch*(privKey = none(PrivateKey), address = MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet(), - triggerSelf = false, - gossip = false, secureManagers: openarray[SecureProtocol] = [ # array cos order matters SecureProtocol.Secio, SecureProtocol.Noise, ], - verifySignature = libp2p_pubsub_verify, - sign = libp2p_pubsub_sign, transportFlags: set[ServerFlags] = {}, - msgIdProvider: MsgIdProvider = defaultMsgIdProvider, rng = newRng(), inTimeout: Duration = 5.minutes, outTimeout: Duration = 5.minutes): Switch = @@ -66,26 +54,11 @@ proc newStandardSwitch*(privKey = none(PrivateKey), of SecureProtocol.Secio: secureManagerInstances &= newSecio(rng, seckey).Secure - let pubSub = if gossip: - newPubSub(GossipSub, - peerInfo = peerInfo, - triggerSelf = triggerSelf, - verifySignature = verifySignature, - sign = sign, - msgIdProvider = msgIdProvider, - params = GossipSubParams.init()).PubSub - else: - newPubSub(FloodSub, - peerInfo = peerInfo, - triggerSelf = triggerSelf, - verifySignature = verifySignature, - sign = sign, - msgIdProvider = msgIdProvider).PubSub - - newSwitch( + let switch = newSwitch( peerInfo, transports, identify, muxers, - secureManagers = secureManagerInstances, - pubSub = some(pubSub)) + secureManagers = secureManagerInstances) + + return switch diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 6f0058070..00547d24b 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -143,8 +143,10 @@ proc initBufferStream*(s: BufferStream, trace "created bufferstream", oid = $s.oid proc newBufferStream*(handler: WriteHandler = nil, - size: int = DefaultBufferSize): BufferStream = + size: int = DefaultBufferSize, + timeout: Duration = DefaultConnectionTimeout): BufferStream = new result + result.timeout = timeout result.initBufferStream(handler, size) proc popFirst*(s: BufferStream): byte = diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 366bc44eb..e6d28cf26 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -45,7 +45,7 @@ template withExceptions(body: untyped) = raise exc except TransportIncompleteError: # for all intents and purposes this is an EOF - raise newLPStreamEOFError() + raise newLPStreamIncompleteError() except TransportLimitError: raise newLPStreamLimitError() except TransportUseClosedError: diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index 55461fd77..a5925d717 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. 
-import hashes +import hashes, oids import chronicles, chronos, metrics import lpstream, ../multiaddress, @@ -20,7 +20,7 @@ logScope: const ConnectionTrackerName* = "libp2p.connection" - DefaultConnectionTimeout* = 1.minutes + DefaultConnectionTimeout* = 5.minutes type TimeoutHandler* = proc(): Future[void] {.gcsafe.} @@ -73,8 +73,15 @@ method initStream*(s: Connection) = procCall LPStream(s).initStream() s.closeEvent = newAsyncEvent() + if isNil(s.timeoutHandler): + s.timeoutHandler = proc() {.async.} = + await s.close() + + trace "timeout", timeout = $s.timeout.millis doAssert(isNil(s.timerTaskFut)) - s.timerTaskFut = s.timeoutMonitor() + # doAssert(s.timeout > 0.millis) + if s.timeout > 0.millis: + s.timerTaskFut = s.timeoutMonitor() inc getConnectionTracker().opened diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 1c4269160..efd9f9440 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -115,8 +115,12 @@ proc readExactly*(s: LPStream, read += await s.readOnce(addr pbuffer[read], nbytes - read) if read < nbytes: - trace "incomplete data received", read - raise newLPStreamIncompleteError() + if s.atEof: + trace "couldn't read all bytes, stream EOF", expected = nbytes, read + raise newLPStreamEOFError() + else: + trace "couldn't read all bytes, incomplete data", expected = nbytes, read + raise newLPStreamIncompleteError() proc readLine*(s: LPStream, limit = 0, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index fc48bf8dc..2aa17b5ae 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -25,12 +25,14 @@ import stream/connection, protocols/secure/secure, peerinfo, protocols/identify, - protocols/pubsub/pubsub, muxers/muxer, connmanager, peerid, errors +chronicles.formatIt(PeerInfo): $it +chronicles.formatIt(PeerID): $it + logScope: topics = "switch" @@ -44,9 +46,6 @@ declareCounter(libp2p_dialed_peers, "dialed peers") declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_upgrade, "peers failed upgrade") -const - MaxPubsubReconnectAttempts* = 10 - type NoPubSubException* = object of CatchableError @@ -77,14 +76,8 @@ type identity*: Identify streamHandler*: StreamHandler secureManagers*: seq[Secure] - pubSub*: Option[PubSub] - running: bool dialLock: Table[PeerID, AsyncLock] ConnEvents: Table[ConnEventKind, HashSet[ConnEventHandler]] - pubsubMonitors: Table[PeerId, Future[void]] - -proc newNoPubSubException(): ref NoPubSubException {.inline.} = - result = newException(NoPubSubException, "no pubsub provided!") proc addConnEventHandler*(s: Switch, handler: ConnEventHandler, kind: ConnEventKind) = @@ -111,23 +104,6 @@ proc triggerConnEvent(s: Switch, peerId: PeerID, event: ConnEvent) {.async, gcsa warn "exception in trigger ConnEvents", exc = exc.msg proc disconnect*(s: Switch, peerId: PeerID) {.async, gcsafe.} -proc subscribePeer*(s: Switch, peerId: PeerID) {.async, gcsafe.} -proc subscribePeerInternal(s: Switch, peerId: PeerID) {.async, gcsafe.} - -proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} = - try: - await conn.closeEvent.wait() - trace "about to cleanup pubsub peer" - if s.pubSub.isSome: - let fut = s.pubsubMonitors.getOrDefault(conn.peerInfo.peerId) - if not(isNil(fut)) and not(fut.finished): - fut.cancel() - - await s.pubSub.get().unsubscribePeer(conn.peerInfo) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception cleaning pubsub peer", exc = exc.msg proc isConnected*(s: Switch, peerId: PeerID): bool = ## returns true if the peer has one or more @@ 
-295,7 +271,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = proc internalConnect(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]): Future[Connection] {.async.} = - logScope: peer = peerId + logScope: + peer = peerId if s.peerInfo.peerId == peerId: raise newException(CatchableError, "can't dial self!") @@ -353,12 +330,12 @@ proc internalConnect(s: Switch, libp2p_failed_upgrade.inc() raise exc - doAssert not isNil(upgraded), "checked in upgradeOutgoing" + doAssert not isNil(upgraded), "connection died after upgradeOutgoing" s.connManager.storeOutgoing(upgraded) conn = upgraded trace "dial successful", - oid = $conn.oid, + oid = $upgraded.oid, peerInfo = shortLog(upgraded.peerInfo) break finally: @@ -381,14 +358,31 @@ proc internalConnect(s: Switch, # unworthy and disconnects it raise newException(CatchableError, "Connection closed during handshake") - asyncCheck s.cleanupPubSubPeer(conn) - asyncCheck s.subscribePeer(peerId) - return conn proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} = discard await s.internalConnect(peerId, addrs) +proc negotiateStream(s: Switch, stream: Connection, proto: string): Future[Connection] {.async.} = + trace "Attempting to select remote", proto = proto, + streamOid = $stream.oid, + oid = $stream.oid + + if not await s.ms.select(stream, proto): + await stream.close() + raise newException(CatchableError, "Unable to select sub-protocol" & proto) + + return stream + +proc dial*(s: Switch, + peerId: PeerID, + proto: string): Future[Connection] {.async.} = + let stream = await s.connmanager.getMuxedStream(peerId) + if stream.isNil: + raise newException(CatchableError, "Couldn't get muxed stream") + + return await s.negotiateStream(stream, proto) + proc dial*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress], @@ -409,14 +403,7 @@ proc dial*(s: Switch, await conn.close() raise newException(CatchableError, "Couldn't get muxed stream") - trace "Attempting to select remote", proto = proto, - streamOid = $stream.oid, - oid = $conn.oid - if not await s.ms.select(stream, proto): - await stream.close() - raise newException(CatchableError, "Unable to select sub-protocol" & proto) - - return stream + return await s.negotiateStream(stream, proto) except CancelledError as exc: trace "dial canceled" await cleanup() @@ -458,21 +445,12 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = s.peerInfo.addrs[i] = t.ma # update peer's address startFuts.add(server) - if s.pubSub.isSome: - await s.pubSub.get().start() - debug "started libp2p node", peer = $s.peerInfo, addrs = s.peerInfo.addrs result = startFuts # listen for incoming connections proc stop*(s: Switch) {.async.} = trace "stopping switch" - # we want to report errors but we do not want to fail - # or crash here, cos we need to clean possibly MANY items - # and any following conn/transport won't be cleaned up - if s.pubSub.isSome: - await s.pubSub.get().stop() - # close and cleanup all connections await s.connManager.close() @@ -486,139 +464,6 @@ proc stop*(s: Switch) {.async.} = trace "switch stopped" -proc subscribePeerInternal(s: Switch, peerId: PeerID) {.async, gcsafe.} = - ## Subscribe to pub sub peer - ## - - if s.pubSub.isSome and not s.pubSub.get().connected(peerId): - trace "about to subscribe to pubsub peer", peer = peerId - var stream: Connection - try: - stream = await s.connManager.getMuxedStream(peerId) - if isNil(stream): - trace "unable to subscribe to peer", peer = peerId - return - - if not await s.ms.select(stream, 
s.pubSub.get().codec): - if not(isNil(stream)): - trace "couldn't select pubsub", codec = s.pubSub.get().codec - await stream.close() - return - - s.pubSub.get().subscribePeer(stream) - await stream.closeEvent.wait() - except CancelledError as exc: - if not(isNil(stream)): - await stream.close() - - raise exc - except CatchableError as exc: - trace "exception in subscribe to peer", peer = peerId, - exc = exc.msg - if not(isNil(stream)): - await stream.close() - -proc pubsubMonitor(s: Switch, peerId: PeerID) {.async.} = - ## while peer connected maintain a - ## pubsub connection as well - ## - - while s.isConnected(peerId): - try: - trace "subscribing to pubsub peer", peer = peerId - await s.subscribePeerInternal(peerId) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "exception in pubsub monitor", peer = peerId, exc = exc.msg - finally: - trace "sleeping before trying pubsub peer", peer = peerId - await sleepAsync(1.seconds) # allow the peer to cooldown - - trace "exiting pubsub monitor", peer = peerId - -proc subscribePeer*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} = - ## Waits until ``server`` is not closed. - ## - - var retFuture = newFuture[void]("stream.transport.server.join") - let pubsubFut = s.pubsubMonitors.mgetOrPut( - peerId, s.pubsubMonitor(peerId)) - - proc continuation(udata: pointer) {.gcsafe.} = - retFuture.complete() - - proc cancel(udata: pointer) {.gcsafe.} = - pubsubFut.removeCallback(continuation, cast[pointer](retFuture)) - - if not(pubsubFut.finished()): - pubsubFut.addCallback(continuation, cast[pointer](retFuture)) - retFuture.cancelCallback = cancel - else: - retFuture.complete() - - return retFuture - -proc subscribe*(s: Switch, topic: string, - handler: TopicHandler) {.async.} = - ## subscribe to a pubsub topic - ## - - if s.pubSub.isNone: - raise newNoPubSubException() - - await s.pubSub.get().subscribe(topic, handler) - -proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} = - ## unsubscribe from topics - ## - - if s.pubSub.isNone: - raise newNoPubSubException() - - await s.pubSub.get().unsubscribe(topics) - -proc unsubscribeAll*(s: Switch, topic: string) {.async.} = - ## unsubscribe from topics - if s.pubSub.isNone: - raise newNoPubSubException() - - await s.pubSub.get().unsubscribeAll(topic) - -proc publish*(s: Switch, - topic: string, - data: seq[byte], - timeout: Duration = InfiniteDuration): Future[int] {.async.} = - ## pubslish to pubsub topic - ## - - if s.pubSub.isNone: - raise newNoPubSubException() - - return await s.pubSub.get().publish(topic, data, timeout) - -proc addValidator*(s: Switch, - topics: varargs[string], - hook: ValidatorHandler) = - ## add validator - ## - - if s.pubSub.isNone: - raise newNoPubSubException() - - s.pubSub.get().addValidator(topics, hook) - -proc removeValidator*(s: Switch, - topics: varargs[string], - hook: ValidatorHandler) = - ## pubslish to pubsub topic - ## - - if s.pubSub.isNone: - raise newNoPubSubException() - - s.pubSub.get().removeValidator(topics, hook) - proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = var stream = await muxer.newStream() defer: @@ -654,10 +499,6 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = asyncCheck s.triggerConnEvent( peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: true)) - # try establishing a pubsub connection - asyncCheck s.cleanupPubSubPeer(muxer.connection) - asyncCheck s.subscribePeer(peerId) - except CancelledError as exc: await muxer.close() raise exc @@ -670,8 +511,7 @@ proc 
newSwitch*(peerInfo: PeerInfo, transports: seq[Transport], identity: Identify, muxers: Table[string, MuxerProvider], - secureManagers: openarray[Secure] = [], - pubSub: Option[PubSub] = none(PubSub)): Switch = + secureManagers: openarray[Secure] = []): Switch = if secureManagers.len == 0: raise (ref CatchableError)(msg: "Provide at least one secure manager") @@ -704,24 +544,21 @@ proc newSwitch*(peerInfo: PeerInfo, val.muxerHandler = proc(muxer: Muxer): Future[void] = s.muxerHandler(muxer) - if pubSub.isSome: - result.pubSub = pubSub - result.mount(pubSub.get()) - -proc isConnected*(s: Switch, peerInfo: PeerInfo): bool {.deprecated: "Use PeerID version".} = +proc isConnected*(s: Switch, peerInfo: PeerInfo): bool + {.deprecated: "Use PeerID version".} = not isNil(peerInfo) and isConnected(s, peerInfo.peerId) -proc disconnect*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version", gcsafe.} = +proc disconnect*(s: Switch, peerInfo: PeerInfo): Future[void] + {.deprecated: "Use PeerID version", gcsafe.} = disconnect(s, peerInfo.peerId) -proc connect*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version".} = +proc connect*(s: Switch, peerInfo: PeerInfo): Future[void] + {.deprecated: "Use PeerID version".} = connect(s, peerInfo.peerId, peerInfo.addrs) proc dial*(s: Switch, peerInfo: PeerInfo, proto: string): - Future[Connection] {.deprecated: "Use PeerID version".} = + Future[Connection] + {.deprecated: "Use PeerID version".} = dial(s, peerInfo.peerId, peerInfo.addrs, proto) - -proc subscribePeer*(s: Switch, peerInfo: PeerInfo): Future[void] {.deprecated: "Use PeerID version", gcsafe.} = - subscribePeer(s, peerInfo.peerId) diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 3d8fdad16..870cfaf56 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -29,9 +29,9 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = # turn things deterministic # this is for testing purposes only var ceil = 15 - let fsub = cast[FloodSub](sender.pubSub.get()) + let fsub = cast[FloodSub](sender) while not fsub.floodsub.hasKey(key) or - not fsub.floodsub.hasPeerID(key, receiver.peerInfo.id): + not fsub.floodsub.hasPeerID(key, receiver.peerInfo.peerId): await sleepAsync(100.millis) dec ceil doAssert(ceil > 0, "waitSub timeout!") @@ -43,7 +43,7 @@ suite "FloodSub": check tracker.isLeaked() == false test "FloodSub basic publish/subscribe A -> B": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var completionFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" @@ -51,19 +51,32 @@ suite "FloodSub": let nodes = generateNodes(2) + + # start switches nodesFut = await allFinished( - nodes[0].start(), - nodes[1].start() + nodes[0].switch.start(), + nodes[1].switch.start(), ) - let subscribes = await subscribeNodes(nodes) + # start pubsub + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0 + check (await completionFut.wait(5.seconds)) == true - result = await completionFut.wait(5.seconds) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) await allFuturesThrowing( nodes[0].stop(), @@ -71,53 +84,80 @@ suite "FloodSub": ) await allFuturesThrowing(nodesFut.concat()) - await 
allFuturesThrowing(subscribes) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "FloodSub basic publish/subscribe B -> A": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var completionFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" completionFut.complete(true) - var nodes = generateNodes(2) - var awaiters: seq[Future[void]] - awaiters.add((await nodes[0].start())) - awaiters.add((await nodes[1].start())) + let + nodes = generateNodes(2) - let subscribes = await subscribeNodes(nodes) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsubcon + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[0].subscribe("foobar", handler) await waitSub(nodes[1], nodes[0], "foobar") check (await nodes[1].publish("foobar", "Hello!".toBytes())) > 0 - result = await completionFut.wait(5.seconds) + check (await completionFut.wait(5.seconds)) == true - await allFuturesThrowing(nodes[0].stop(), nodes[1].stop()) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) + await allFuturesThrowing( + nodes[0].stop(), + nodes[1].stop() + ) - check: - waitFor(runTests()) == true + await allFuturesThrowing(nodesFut) + + waitFor(runTests()) test "FloodSub validation should succeed": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var handlerFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" handlerFut.complete(true) - var nodes = generateNodes(2) - var awaiters: seq[Future[void]] - awaiters.add((await nodes[0].start())) - awaiters.add((await nodes[1].start())) + let + nodes = generateNodes(2) + + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsubcon + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) - let subscribes = await subscribeNodes(nodes) await nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") @@ -131,30 +171,44 @@ suite "FloodSub": nodes[1].addValidator("foobar", validator) check (await nodes[0].publish("foobar", "Hello!".toBytes())) > 0 - check (await handlerFut) == true + + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) + await allFuturesThrowing( nodes[0].stop(), - nodes[1].stop()) + nodes[1].stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) - result = true + await allFuturesThrowing(nodesFut) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "FloodSub validation should fail": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check false # if we get here, it should fail - var nodes = generateNodes(2) - var awaiters: seq[Future[void]] - awaiters.add((await nodes[0].start())) - awaiters.add((await nodes[1].start())) + let + nodes = generateNodes(2) - let subscribes = await subscribeNodes(nodes) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsubcon + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await 
subscribeNodes(nodes) await nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") @@ -168,30 +222,44 @@ suite "FloodSub": discard await nodes[0].publish("foobar", "Hello!".toBytes()) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) + await allFuturesThrowing( nodes[0].stop(), - nodes[1].stop()) + nodes[1].stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) - result = true + await allFuturesThrowing(nodesFut) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "FloodSub validation one fails and one succeeds": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var handlerFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foo" handlerFut.complete(true) - var nodes = generateNodes(2) - var awaiters: seq[Future[void]] - awaiters.add((await nodes[0].start())) - awaiters.add((await nodes[1].start())) + let + nodes = generateNodes(2) - let subscribes = await subscribeNodes(nodes) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsubcon + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[1].subscribe("foo", handler) await waitSub(nodes[0], nodes[1], "foo") await nodes[1].subscribe("bar", handler) @@ -210,57 +278,21 @@ suite "FloodSub": check (await nodes[0].publish("bar", "Hello!".toBytes())) > 0 await allFuturesThrowing( - nodes[0].stop(), - nodes[1].stop()) - - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) - result = true - - check: - waitFor(runTests()) == true - - test "FloodSub publish should fail on timeout": - proc runTests(): Future[bool] {.async.} = - proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = - discard - - var nodes = generateNodes(2) - var awaiters: seq[Future[void]] - awaiters.add((await nodes[0].start())) - awaiters.add((await nodes[1].start())) - - let subscribes = await subscribeNodes(nodes) - await nodes[1].subscribe("foobar", handler) - await waitSub(nodes[0], nodes[1], "foobar") - - let pubsub = nodes[0].pubSub.get() - let peer = pubsub.peers[nodes[1].peerInfo.id] - - peer.conn = Connection(newBufferStream( - proc (data: seq[byte]) {.async, gcsafe.} = - await sleepAsync(10.seconds) - ,size = 0)) - - let in10millis = Moment.fromNow(10.millis) - let sent = await nodes[0].publish("foobar", "Hello!".toBytes(), 10.millis) - - check Moment.now() >= in10millis - check sent == 0 + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) await allFuturesThrowing( nodes[0].stop(), - nodes[1].stop()) + nodes[1].stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) - result = true + await allFuturesThrowing(nodesFut) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "FloodSub multiple peers, no self trigger": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var runs = 10 var futs = newSeq[(Future[void], TopicHandler, ref int)](runs) @@ -279,15 +311,12 @@ suite "FloodSub": counter ) - var nodes: seq[Switch] = newSeq[Switch]() - for i in 0..= in10millis - check sent == 0 + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) await allFuturesThrowing( nodes[0].stop(), - nodes[1].stop()) + nodes[1].stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaiters) - result = true + await 
allFuturesThrowing(nodesFut.concat()) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "e2e - GossipSub should add remote peer topic subscriptions": - proc testBasicGossipSub(): Future[bool] {.async.} = + proc testBasicGossipSub() {.async.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = discard - var nodes: seq[Switch] = newSeq[Switch]() - for i in 0..<2: - nodes.add newStandardSwitch(gossip = true, - secureManagers = [SecureProtocol.Noise]) + let + nodes = generateNodes( + 2, + gossip = true, + secureManagers = [SecureProtocol.Noise]) - var awaitters: seq[Future[void]] - for node in nodes: - awaitters.add(await node.start()) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsub + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) - let subscribes = await subscribeNodes(nodes) await nodes[1].subscribe("foobar", handler) await sleepAsync(10.seconds) - let gossip1 = GossipSub(nodes[0].pubSub.get()) - let gossip2 = GossipSub(nodes[1].pubSub.get()) + let gossip1 = GossipSub(nodes[0]) + let gossip2 = GossipSub(nodes[1]) check: "foobar" in gossip2.topics "foobar" in gossip1.gossipsub - gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) + gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.peerId) - await allFuturesThrowing(nodes.mapIt(it.stop())) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaitters) + await allFuturesThrowing( + nodes[0].stop(), + nodes[1].stop() + ) - result = true + await allFuturesThrowing(nodesFut.concat()) - check: - waitFor(testBasicGossipSub()) == true + waitFor(testBasicGossipSub()) test "e2e - GossipSub should add remote peer topic subscriptions if both peers are subscribed": - proc testBasicGossipSub(): Future[bool] {.async.} = + proc testBasicGossipSub() {.async.} = proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = discard - var nodes: seq[Switch] = newSeq[Switch]() - for i in 0..<2: - nodes.add newStandardSwitch(gossip = true, secureManagers = [SecureProtocol.Secio]) + let + nodes = generateNodes( + 2, + gossip = true, + secureManagers = [SecureProtocol.Secio]) - var awaitters: seq[Future[void]] - for node in nodes: - awaitters.add(await node.start()) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) - let subscribes = await subscribeNodes(nodes) + # start pubsub + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[0].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler) @@ -342,8 +373,8 @@ suite "GossipSub": await allFuturesThrowing(subs) let - gossip1 = GossipSub(nodes[0].pubSub.get()) - gossip2 = GossipSub(nodes[1].pubSub.get()) + gossip1 = GossipSub(nodes[0]) + gossip2 = GossipSub(nodes[1]) check: "foobar" in gossip1.topics @@ -352,35 +383,53 @@ suite "GossipSub": "foobar" in gossip1.gossipsub "foobar" in gossip2.gossipsub - gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) or - gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) + gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.peerId) or + gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId) - gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.id) or - gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id) + 
gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.peerId) or + gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.peerId) - await allFuturesThrowing(nodes.mapIt(it.stop())) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaitters) + await allFuturesThrowing( + nodes[0].stop(), + nodes[1].stop() + ) - result = true + await allFuturesThrowing(nodesFut.concat()) - check: - waitFor(testBasicGossipSub()) == true + waitFor(testBasicGossipSub()) test "e2e - GossipSub send over fanout A -> B": - proc runTests(): Future[bool] {.async.} = + proc runTests() {.async.} = var passed = newFuture[void]() proc handler(topic: string, data: seq[byte]) {.async, gcsafe.} = check topic == "foobar" passed.complete() - var nodes = generateNodes(2, true) - var wait = newSeq[Future[void]]() - wait.add(await nodes[0].start()) - wait.add(await nodes[1].start()) + let + nodes = generateNodes( + 2, + gossip = true, + secureManagers = [SecureProtocol.Secio]) - let subscribes = await subscribeNodes(nodes) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), + ) + + # start pubsub + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") @@ -393,18 +442,19 @@ suite "GossipSub": obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) = inc observed ) - nodes[1].pubsub.get().addObserver(obs1) - nodes[0].pubsub.get().addObserver(obs2) + + # nodes[1].addObserver(obs1) + # nodes[0].addObserver(obs2) tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 - var gossip1: GossipSub = GossipSub(nodes[0].pubSub.get()) - var gossip2: GossipSub = GossipSub(nodes[1].pubSub.get()) + var gossip1: GossipSub = GossipSub(nodes[0]) + var gossip2: GossipSub = GossipSub(nodes[1]) check: "foobar" in gossip1.gossipsub - gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.id) - not gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) + gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.peerId) + not gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId) await passed.wait(2.seconds) @@ -413,14 +463,20 @@ suite "GossipSub": await nodes[0].stop() await nodes[1].stop() - await allFuturesThrowing(subscribes) - await allFuturesThrowing(wait) + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) - check observed == 2 - result = true + await allFuturesThrowing( + nodes[0].stop(), + nodes[1].stop() + ) - check: - waitFor(runTests()) == true + await allFuturesThrowing(nodesFut.concat()) + # check observed == 2 + + waitFor(runTests()) test "e2e - GossipSub send over mesh A -> B": proc runTests(): Future[bool] {.async.} = @@ -429,16 +485,26 @@ suite "GossipSub": check topic == "foobar" passed.complete(true) - var nodes = generateNodes(2, true) - var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get()) - gossipSub1.parameters.floodPublish = false - var gossipSub2: GossipSub = GossipSub(nodes[1].pubSub.get()) - gossipSub2.parameters.floodPublish = false - var wait: seq[Future[void]] - wait.add(await nodes[0].start()) - wait.add(await nodes[1].start()) + let + nodes = generateNodes( + 2, + gossip = true, + secureManagers = [SecureProtocol.Secio]) - let subscribes = await subscribeNodes(nodes) + # start switches + nodesFut = await allFinished( + nodes[0].switch.start(), + nodes[1].switch.start(), 
+ ) + + # start pubsub + await allFuturesThrowing( + allFinished( + nodes[0].start(), + nodes[1].start(), + )) + + await subscribeNodes(nodes) await nodes[0].subscribe("foobar", handler) await nodes[1].subscribe("foobar", handler) @@ -448,41 +514,42 @@ suite "GossipSub": result = await passed - var gossip1: GossipSub = GossipSub(nodes[0].pubSub.get()) - var gossip2: GossipSub = GossipSub(nodes[1].pubSub.get()) + var gossip1: GossipSub = GossipSub(nodes[0]) + var gossip2: GossipSub = GossipSub(nodes[1]) check: "foobar" in gossip1.gossipsub "foobar" in gossip2.gossipsub - gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) - not gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.id) - gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id) - not gossip2.fanout.hasPeerID("foobar", gossip1.peerInfo.id) + gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.peerId) + not gossip1.fanout.hasPeerID("foobar", gossip2.peerInfo.peerId) + gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.peerId) + not gossip2.fanout.hasPeerID("foobar", gossip1.peerInfo.peerId) - await nodes[0].stop() - await nodes[1].stop() + await allFuturesThrowing( + nodes[0].switch.stop(), + nodes[1].switch.stop() + ) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(wait) + await allFuturesThrowing( + nodes[0].stop(), + nodes[1].stop() + ) + + await allFuturesThrowing(nodesFut.concat()) check: waitFor(runTests()) == true test "e2e - GossipSub with multiple peers": - proc runTests(): Future[bool] {.async.} = - var nodes: seq[Switch] = newSeq[Switch]() - var awaitters: seq[Future[void]] + proc runTests() {.async.} = var runs = 10 - for i in 0..= 1 for node in nodes: - var gossip: GossipSub = GossipSub(node.pubSub.get()) + var gossip = GossipSub(node) + check: "foobar" in gossip.gossipsub gossip.fanout.len == 0 gossip.mesh["foobar"].len > 0 - await allFuturesThrowing(nodes.mapIt(it.stop())) + await allFuturesThrowing( + nodes.mapIt( + allFutures( + it.stop(), + it.switch.stop()))) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaitters) - result = true + await allFuturesThrowing(nodesFut) - check: - waitFor(runTests()) == true + waitFor(runTests()) test "e2e - GossipSub with multiple peers (sparse)": - proc runTests(): Future[bool] {.async.} = - var nodes: seq[Switch] = newSeq[Switch]() - var awaitters: seq[Future[void]] + proc runTests() {.async.} = var runs = 10 - for i in 0..= 1 for node in nodes: - var gossip: GossipSub = GossipSub(node.pubSub.get()) + var gossip = GossipSub(node) check: "foobar" in gossip.gossipsub gossip.fanout.len == 0 gossip.mesh["foobar"].len > 0 - await allFuturesThrowing(nodes.mapIt(it.stop())) + await allFuturesThrowing( + nodes.mapIt( + allFutures( + it.stop(), + it.switch.stop()))) - await allFuturesThrowing(subscribes) - await allFuturesThrowing(awaitters) - result = true + await allFuturesThrowing(nodesFut) - check: - waitFor(runTests()) == true + waitFor(runTests()) diff --git a/tests/pubsub/testmessage.nim b/tests/pubsub/testmessage.nim index 571a0566c..8a70dd2dd 100644 --- a/tests/pubsub/testmessage.nim +++ b/tests/pubsub/testmessage.nim @@ -16,4 +16,4 @@ suite "Message": peer = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) msg = Message.init(peer, @[], "topic", seqno, sign = true) - check verify(msg, peer) + check verify(msg, peer.peerId) diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 0bc13cb38..69229a60f 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -1,27 +1,65 @@ -import random, options +# compile time options 
here +const + libp2p_pubsub_sign {.booldefine.} = true + libp2p_pubsub_verify {.booldefine.} = true + +import random import chronos -import ../../libp2p/standard_setup -import ../../libp2p/protocols/pubsub/gossipsub +import ../../libp2p/[standard_setup, + protocols/pubsub/pubsub, + protocols/pubsub/floodsub, + protocols/pubsub/gossipsub, + protocols/secure/secure] + export standard_setup randomize() -proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] = - for i in 0..