From fec507e7559d7f26b750cffa29a373e03fd42ecd Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 02:06:26 +0900 Subject: [PATCH 01/23] Add peers back to gossipsub table, slow down heartbeat (#256) * Add peers back to gossipsub table, slow down heartbeat * exclude on unsub from mesh and fanout --- libp2p/protocols/pubsub/gossipsub.nim | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index f6a4a2eb7..d2efad55b 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -168,6 +168,7 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = # send a graft message to the peer await p.sendPrune(@[topic]) g.mesh[topic].excl(id) + g.gossipsub[topic].incl(id) libp2p_gossipsub_peers_per_topic_gossipsub .set(g.gossipsub.getOrDefault(topic).len.int64, @@ -274,7 +275,7 @@ proc heartbeat(g: GossipSub) {.async.} = except CatchableError as exc: trace "exception ocurred in gossipsub heartbeat", exc = exc.msg - await sleepAsync(1.seconds) + await sleepAsync(5.seconds) method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = ## handle peer disconnects @@ -323,6 +324,10 @@ method subscribeTopic*(g: GossipSub, trace "removing subscription for topic", peer = peerId, name = topic # unsubscribe remote peer from the topic g.gossipsub[topic].excl(peerId) + if peerId in g.mesh.getOrDefault(topic): + g.mesh[topic].excl(peerId) + if peerId in g.fanout.getOrDefault(topic): + g.fanout[topic].excl(peerId) libp2p_gossipsub_peers_per_topic_gossipsub .set(g.gossipsub[topic].len.int64, labelValues = [topic]) @@ -362,6 +367,7 @@ proc handlePrune(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) = if prune.topicID in g.mesh: g.mesh[prune.topicID].excl(peer.id) + g.gossipsub[prune.topicID].incl(peer.id) libp2p_gossipsub_peers_per_topic_mesh .set(g.mesh[prune.topicID].len.int64, labelValues = [prune.topicID]) From 4698f41a9111227e8572ddd029849f64d5b67ada Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 12:23:03 +0900 Subject: [PATCH 02/23] Remove stacktrace logging from pubsub connect --- libp2p/protocols/pubsub/pubsub.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index d30bec7bc..85ca0da92 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -129,7 +129,7 @@ proc getPeer(p: PubSub, # create new pubsub peer let peer = newPubSubPeer(peerInfo, proto) - trace "created new pubsub peer", peerId = peer.id, stack = getStackTrace() + trace "created new pubsub peer", peerId = peer.id p.peers[peer.id] = peer peer.observers = p.observers From 4bcb567d47d80cfbd2b5ff5b5642945913a45083 Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 12:34:36 +0900 Subject: [PATCH 03/23] fix gossip tests --- libp2p/protocols/pubsub/gossipsub.nim | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index d2efad55b..9494b1872 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -168,7 +168,8 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = # send a graft message to the peer await p.sendPrune(@[topic]) g.mesh[topic].excl(id) - g.gossipsub[topic].incl(id) + if topic in g.gossipsub: + g.gossipsub[topic].incl(id) libp2p_gossipsub_peers_per_topic_gossipsub 
.set(g.gossipsub.getOrDefault(topic).len.int64, From 9b8b159abb34e5e2f894a602ddb50e9efb12f1fb Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 13:19:34 +0900 Subject: [PATCH 04/23] Remove other spurious getStacktrace in pubsub traces --- libp2p/protocols/pubsub/pubsub.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 85ca0da92..e6e7d2ef5 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -63,7 +63,7 @@ method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = ## handle peer disconnects ## if peer.id in p.peers: - trace "deleting peer", peer = peer.id, stack = getStackTrace() + trace "deleting peer", peer = peer.id p.peers[peer.id] = nil p.peers.del(peer.id) From f9e0a1f069c836195e66dfb551631f9e2be43c21 Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 13:56:59 +0900 Subject: [PATCH 05/23] CI fix handleDisconnect (pubsub) --- libp2p/protocols/pubsub/pubsub.nim | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index e6e7d2ef5..dbfe012ba 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -62,9 +62,8 @@ type method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = ## handle peer disconnects ## - if peer.id in p.peers: + if not isNil(peer.peerInfo) and peer.id in p.peers: trace "deleting peer", peer = peer.id - p.peers[peer.id] = nil p.peers.del(peer.id) # metrics From 4e12d0d97ad1b4fb6e04320e87f2a739a0cc557d Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Thu, 9 Jul 2020 17:20:45 +0900 Subject: [PATCH 06/23] nil check peer before disconnect --- libp2p/protocols/pubsub/floodsub.nim | 3 ++- libp2p/protocols/pubsub/gossipsub.nim | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index f629270b0..c83c2664d 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -138,7 +138,8 @@ method publish*(f: FloodSub, let (published, failed) = await f.sendHelper(f.floodsub.getOrDefault(topic), @[msg]) for p in failed: let peer = f.peers.getOrDefault(p) - f.handleDisconnect(peer) # cleanup failed peers + if not isNil(peer): + f.handleDisconnect(peer) # cleanup failed peers libp2p_pubsub_messages_published.inc(labelValues = [topic]) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 9494b1872..0fa25d25e 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -450,7 +450,7 @@ method rpcHandler*(g: GossipSub, let (published, failed) = await g.sendHelper(toSendPeers, m.messages) for p in failed: let peer = g.peers.getOrDefault(p) - if not(isNil(peer)): + if not isNil(peer): g.handleDisconnect(peer) # cleanup failed peers trace "forwared message to peers", peers = published.len @@ -523,7 +523,8 @@ method publish*(g: GossipSub, let (published, failed) = await g.sendHelper(peers, @[msg]) for p in failed: let peer = g.peers.getOrDefault(p) - g.handleDisconnect(peer) # cleanup failed peers + if not isNil(peer): + g.handleDisconnect(peer) # cleanup failed peers if published.len > 0: libp2p_pubsub_messages_published.inc(labelValues = [topic]) From 45c089ff0d81938a51b97a82c4504871a9c286fb Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 9 Jul 2020 10:53:19 +0200 
Subject: [PATCH 07/23] noise updates (#255) * clear secrets explicitly * simplify keygen * avoid some trivial memory allocations * fix little endian encoding of nonce --- libp2p/crypto/curve25519.nim | 13 +- libp2p/crypto/ecnist.nim | 10 +- libp2p/protocols/secure/noise.nim | 277 +++++++++++++++-------------- libp2p/protocols/secure/secure.nim | 7 +- libp2p/stream/streamseq.nim | 12 ++ 5 files changed, 173 insertions(+), 146 deletions(-) diff --git a/libp2p/crypto/curve25519.nim b/libp2p/crypto/curve25519.nim index fbb84a3b3..18580d790 100644 --- a/libp2p/crypto/curve25519.nim +++ b/libp2p/crypto/curve25519.nim @@ -105,10 +105,13 @@ proc mulgen*(_: type[Curve25519], dst: var Curve25519Key, point: Curve25519Key) proc public*(private: Curve25519Key): Curve25519Key = Curve25519.mulgen(result, private) -proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Result[Curve25519Key, Curve25519Error] = +proc random*(_: type[Curve25519Key], rng: var BrHmacDrbgContext): Curve25519Key = var res: Curve25519Key let defaultBrEc = brEcGetDefault() - if brEcKeygen(addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) != Curve25519KeySize: - err(Curver25519GenError) - else: - ok(res) + let len = brEcKeygen( + addr rng.vtable, defaultBrEc, nil, addr res[0], EC_curve25519) + # Per bearssl documentation, the keygen only fails if the curve is + # unrecognised - + doAssert len == Curve25519KeySize, "Could not generate curve" + + res diff --git a/libp2p/crypto/ecnist.nim b/libp2p/crypto/ecnist.nim index c33ea90c8..30475ef4b 100644 --- a/libp2p/crypto/ecnist.nim +++ b/libp2p/crypto/ecnist.nim @@ -39,11 +39,11 @@ const type EcPrivateKey* = ref object - buffer*: seq[byte] + buffer*: array[BR_EC_KBUF_PRIV_MAX_SIZE, byte] key*: BrEcPrivateKey EcPublicKey* = ref object - buffer*: seq[byte] + buffer*: array[BR_EC_KBUF_PUB_MAX_SIZE, byte] key*: BrEcPublicKey EcKeyPair* = object @@ -237,7 +237,6 @@ proc random*( ## secp521r1). 
var ecimp = brEcGetDefault() var res = new EcPrivateKey - res.buffer = newSeq[byte](BR_EC_KBUF_PRIV_MAX_SIZE) if brEcKeygen(addr rng.vtable, ecimp, addr res.key, addr res.buffer[0], cast[cint](kind)) == 0: @@ -254,7 +253,6 @@ proc getKey*(seckey: EcPrivateKey): EcResult[EcPublicKey] = if seckey.key.curve in EcSupportedCurvesCint: var length = getPublicKeyLength(cast[EcCurveKind](seckey.key.curve)) var res = new EcPublicKey - res.buffer = newSeq[byte](length) if brEcComputePublicKey(ecimp, addr res.key, addr res.buffer[0], unsafeAddr seckey.key) == 0: err(EcKeyIncorrectError) @@ -621,7 +619,6 @@ proc init*(key: var EcPrivateKey, data: openarray[byte]): Result[void, Asn1Error if checkScalar(raw.toOpenArray(), curve) == 1'u32: key = new EcPrivateKey - key.buffer = newSeq[byte](raw.length) copyMem(addr key.buffer[0], addr raw.buffer[raw.offset], raw.length) key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.xlen = raw.length @@ -681,7 +678,6 @@ proc init*(pubkey: var EcPublicKey, data: openarray[byte]): Result[void, Asn1Err if checkPublic(raw.toOpenArray(), curve) != 0: pubkey = new EcPublicKey - pubkey.buffer = newSeq[byte](raw.length) copyMem(addr pubkey.buffer[0], addr raw.buffer[raw.offset], raw.length) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.qlen = raw.length @@ -769,7 +765,6 @@ proc initRaw*(key: var EcPrivateKey, data: openarray[byte]): bool = if checkScalar(data, curve) == 1'u32: let length = len(data) key = new EcPrivateKey - key.buffer = newSeq[byte](length) copyMem(addr key.buffer[0], unsafeAddr data[0], length) key.key.x = cast[ptr cuchar](addr key.buffer[0]) key.key.xlen = length @@ -801,7 +796,6 @@ proc initRaw*(pubkey: var EcPublicKey, data: openarray[byte]): bool = if checkPublic(data, curve) != 0: let length = len(data) pubkey = new EcPublicKey - pubkey.buffer = newSeq[byte](length) copyMem(addr pubkey.buffer[0], unsafeAddr data[0], length) pubkey.key.q = cast[ptr cuchar](addr pubkey.buffer[0]) pubkey.key.qlen = length diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 929104bd2..fa14674bb 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -12,15 +12,13 @@ import chronicles import bearssl import stew/[endians2, byteutils] import nimcrypto/[utils, sha2, hmac] -import ../../stream/lpstream +import ../../stream/[connection, streamseq] import ../../peerid import ../../peerinfo import ../../protobuf/minprotobuf import ../../utility -import ../../stream/lpstream import secure, - ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf], - ../../stream/bufferstream + ../../crypto/[crypto, chacha20poly1305, curve25519, hkdf] logScope: topics = "noise" @@ -34,7 +32,7 @@ const ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256" # Empty is a special value which indicates k has not yet been initialized. 
- EmptyKey: ChaChaPolyKey = [0.byte, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + EmptyKey = default(ChaChaPolyKey) NonceMax = uint64.high - 1 # max is reserved NoiseSize = 32 MaxPlainSize = int(uint16.high - NoiseSize - ChaChaPolyTag.len) @@ -72,7 +70,7 @@ type Noise* = ref object of Secure rng: ref BrHmacDrbgContext localPrivateKey: PrivateKey - localPublicKey: PublicKey + localPublicKey: seq[byte] noiseKeys: KeyPair commonPrologue: seq[byte] outgoing: bool @@ -89,7 +87,7 @@ type # Utility proc genKeyPair(rng: var BrHmacDrbgContext): KeyPair = - result.privateKey = Curve25519Key.random(rng).tryGet() + result.privateKey = Curve25519Key.random(rng) result.publicKey = result.privateKey.public() proc hashProtocol(name: string): MDigest[256] = @@ -110,12 +108,11 @@ proc dh(priv: Curve25519Key, pub: Curve25519Key): Curve25519Key = proc hasKey(cs: CipherState): bool = cs.k != EmptyKey -proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = +proc encryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = var tag: ChaChaPolyTag nonce: ChaChaPolyNonce - np = cast[ptr uint64](addr nonce[4]) - np[] = state.n + nonce[4..<12] = toBytesLE(state.n) result = @data ChaChaPoly.encrypt(state.k, nonce, tag, result, ad) inc state.n @@ -124,13 +121,12 @@ proc encryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] result &= tag trace "encryptWithAd", tag = byteutils.toHex(tag), data = result.shortLog, nonce = state.n - 1 -proc decryptWithAd(state: var CipherState, ad, data: openarray[byte]): seq[byte] = +proc decryptWithAd(state: var CipherState, ad, data: openArray[byte]): seq[byte] = var - tagIn = data[^ChaChaPolyTag.len..data.high].intoChaChaPolyTag - tagOut = tagIn + tagIn = data.toOpenArray(data.len - ChaChaPolyTag.len, data.high).intoChaChaPolyTag + tagOut: ChaChaPolyTag nonce: ChaChaPolyNonce - np = cast[ptr uint64](addr nonce[4]) - np[] = state.n + nonce[4..<12] = toBytesLE(state.n) result = data[0..(data.high - ChaChaPolyTag.len)] ChaChaPoly.decrypt(state.k, nonce, tagOut, result, ad) trace "decryptWithAd", tagIn = tagIn.shortLog, tagOut = tagOut.shortLog, nonce = state.n @@ -156,7 +152,7 @@ proc mixKey(ss: var SymmetricState, ikm: ChaChaPolyKey) = ss.cs = CipherState(k: temp_keys[1]) trace "mixKey", key = ss.cs.k.shortLog -proc mixHash(ss: var SymmetricState; data: openarray[byte]) = +proc mixHash(ss: var SymmetricState; data: openArray[byte]) = var ctx: sha256 ctx.init() ctx.update(ss.h.data) @@ -165,7 +161,7 @@ proc mixHash(ss: var SymmetricState; data: openarray[byte]) = trace "mixHash", hash = ss.h.data.shortLog # We might use this for other handshake patterns/tokens -proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} = +proc mixKeyAndHash(ss: var SymmetricState; ikm: openArray[byte]) {.used.} = var temp_keys: array[3, ChaChaPolyKey] sha256.hkdf(ss.ck, ikm, [], temp_keys) @@ -173,7 +169,7 @@ proc mixKeyAndHash(ss: var SymmetricState; ikm: openarray[byte]) {.used.} = ss.mixHash(temp_keys[1]) ss.cs = CipherState(k: temp_keys[2]) -proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = +proc encryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] = # according to spec if key is empty leave plaintext if ss.cs.hasKey: result = ss.cs.encryptWithAd(ss.h.data, data) @@ -181,7 +177,7 @@ proc encryptAndHash(ss: var SymmetricState, data: openarray[byte]): seq[byte] = result = @data ss.mixHash(result) -proc decryptAndHash(ss: var 
SymmetricState, data: openarray[byte]): seq[byte] = +proc decryptAndHash(ss: var SymmetricState, data: openArray[byte]): seq[byte] = # according to spec if key is empty leave plaintext if ss.cs.hasKey: result = ss.cs.decryptWithAd(ss.h.data, data) @@ -202,13 +198,13 @@ template write_e: untyped = trace "noise write e" # Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key). hs.e = genKeyPair(p.rng[]) - msg &= hs.e.publicKey + msg.add hs.e.publicKey hs.ss.mixHash(hs.e.publicKey) template write_s: untyped = trace "noise write s" # Appends EncryptAndHash(s.public_key) to the buffer. - msg &= hs.ss.encryptAndHash(hs.s.publicKey) + msg.add hs.ss.encryptAndHash(hs.s.publicKey) template dh_ee: untyped = trace "noise dh ee" @@ -244,8 +240,8 @@ template read_e: untyped = raise newException(NoiseHandshakeError, "Noise E, expected more data") # Sets re (which must be empty) to the next DHLEN bytes from the message. Calls MixHash(re.public_key). - hs.re[0..Curve25519Key.high] = msg[0..Curve25519Key.high] - msg = msg[Curve25519Key.len..msg.high] + hs.re[0..Curve25519Key.high] = msg.toOpenArray(0, Curve25519Key.high) + msg.consume(Curve25519Key.len) hs.ss.mixHash(hs.re) template read_s: untyped = @@ -253,30 +249,33 @@ template read_s: untyped = # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes otherwise. # Sets rs (which must be empty) to DecryptAndHash(temp). let - temp = + rsLen = if hs.ss.cs.hasKey: if msg.len < Curve25519Key.len + ChaChaPolyTag.len: raise newException(NoiseHandshakeError, "Noise S, expected more data") - msg[0..Curve25519Key.high + ChaChaPolyTag.len] + Curve25519Key.len + ChaChaPolyTag.len else: if msg.len < Curve25519Key.len: raise newException(NoiseHandshakeError, "Noise S, expected more data") - msg[0..Curve25519Key.high] - msg = msg[temp.len..msg.high] - let plain = hs.ss.decryptAndHash(temp) - hs.rs[0..Curve25519Key.high] = plain + Curve25519Key.len + hs.rs[0..Curve25519Key.high] = + hs.ss.decryptAndHash(msg.toOpenArray(0, rsLen - 1)) + + msg.consume(rsLen) proc receiveHSMessage(sconn: Connection): Future[seq[byte]] {.async.} = var besize: array[2, byte] await sconn.readExactly(addr besize[0], besize.len) let size = uint16.fromBytesBE(besize).int trace "receiveHSMessage", size + if size == 0: + return + var buffer = newSeq[byte](size) - if buffer.len > 0: - await sconn.readExactly(addr buffer[0], buffer.len) + await sconn.readExactly(addr buffer[0], buffer.len) return buffer -proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = +proc sendHSMessage(sconn: Connection; buf: openArray[byte]): Future[void] = var lesize = buf.len.uint16 besize = lesize.toBytesBE @@ -284,97 +283,106 @@ proc sendHSMessage(sconn: Connection; buf: seq[byte]) {.async.} = trace "sendHSMessage", size = lesize outbuf &= besize outbuf &= buf - await sconn.write(outbuf) + sconn.write(outbuf) -proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = +proc handshakeXXOutbound( + p: Noise, conn: Connection, + p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} = const initiator = true - var hs = HandshakeState.init() - p2psecret = p2pProof.buffer - hs.ss.mixHash(p.commonPrologue) - hs.s = p.noiseKeys + try: - # -> e - var msg: seq[byte] + hs.ss.mixHash(p.commonPrologue) + hs.s = p.noiseKeys - write_e() + # -> e + var msg: StreamSeq - # IK might use this btw! 
- msg &= hs.ss.encryptAndHash(@[]) + write_e() - await conn.sendHSMessage(msg) + # IK might use this btw! + msg.add hs.ss.encryptAndHash([]) - # <- e, ee, s, es + await conn.sendHSMessage(msg.data) - msg = await conn.receiveHSMessage() + # <- e, ee, s, es - read_e() - dh_ee() - read_s() - dh_es() + msg.assign(await conn.receiveHSMessage()) - let remoteP2psecret = hs.ss.decryptAndHash(msg) + read_e() + dh_ee() + read_s() + dh_es() - # -> s, se + let remoteP2psecret = hs.ss.decryptAndHash(msg.data) + msg.clear() - msg.setLen(0) + # -> s, se - write_s() - dh_se() + write_s() + dh_se() - # last payload must follow the ecrypted way of sending - msg &= hs.ss.encryptAndHash(p2psecret) + # last payload must follow the encrypted way of sending + msg.add hs.ss.encryptAndHash(p2psecret) - await conn.sendHSMessage(msg) + await conn.sendHSMessage(msg.data) - let (cs1, cs2) = hs.ss.split() - return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + let (cs1, cs2) = hs.ss.split() + return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + finally: + burnMem(hs) -proc handshakeXXInbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Future[HandshakeResult] {.async.} = +proc handshakeXXInbound( + p: Noise, conn: Connection, + p2pSecret: seq[byte]): Future[HandshakeResult] {.async.} = const initiator = false var hs = HandshakeState.init() - p2psecret = p2pProof.buffer - hs.ss.mixHash(p.commonPrologue) - hs.s = p.noiseKeys + try: + hs.ss.mixHash(p.commonPrologue) + hs.s = p.noiseKeys - # -> e + # -> e - var msg = await conn.receiveHSMessage() + var msg: StreamSeq + msg.add(await conn.receiveHSMessage()) - read_e() + read_e() - # we might use this early data one day, keeping it here for clarity - let earlyData {.used.} = hs.ss.decryptAndHash(msg) + # we might use this early data one day, keeping it here for clarity + let earlyData {.used.} = hs.ss.decryptAndHash(msg.data) - # <- e, ee, s, es + # <- e, ee, s, es - msg.setLen(0) + msg.consume(msg.len) - write_e() - dh_ee() - write_s() - dh_es() + write_e() + dh_ee() + write_s() + dh_es() - msg &= hs.ss.encryptAndHash(p2psecret) + msg.add hs.ss.encryptAndHash(p2psecret) - await conn.sendHSMessage(msg) + await conn.sendHSMessage(msg.data) + msg.clear() - # -> s, se + # -> s, se - msg = await conn.receiveHSMessage() + msg.add(await conn.receiveHSMessage()) - read_s() - dh_se() + read_s() + dh_se() - let remoteP2psecret = hs.ss.decryptAndHash(msg) - - let (cs1, cs2) = hs.ss.split() - return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + let + remoteP2psecret = hs.ss.decryptAndHash(msg.data) + (cs1, cs2) = hs.ss.split() + return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs) + finally: + burnMem(hs) method readMessage*(sconn: NoiseConnection): Future[seq[byte]] {.async.} = while true: # Discard 0-length payloads @@ -399,7 +407,8 @@ method write*(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async. 
while left > 0: let chunkSize = if left > MaxPlainSize: MaxPlainSize else: left - cipher = sconn.writeCs.encryptWithAd([], message.toOpenArray(offset, offset + chunkSize - 1)) + cipher = sconn.writeCs.encryptWithAd( + [], message.toOpenArray(offset, offset + chunkSize - 1)) left = left - chunkSize offset = offset + chunkSize var @@ -421,65 +430,75 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon var libp2pProof = initProtoBuffer() - libp2pProof.write(initProtoField(1, p.localPublicKey.getBytes.tryGet())) + libp2pProof.write(initProtoField(1, p.localPublicKey)) libp2pProof.write(initProtoField(2, signedPayload.getBytes())) # data field also there but not used! libp2pProof.finish() - let handshakeRes = + var handshakeRes = if initiator: - await handshakeXXOutbound(p, conn, libp2pProof) + await handshakeXXOutbound(p, conn, libp2pProof.buffer) else: - await handshakeXXInbound(p, conn, libp2pProof) + await handshakeXXInbound(p, conn, libp2pProof.buffer) - var - remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret) - remotePubKey: PublicKey - remotePubKeyBytes: seq[byte] - remoteSig: Signature - remoteSigBytes: seq[byte] + var secure = try: + var + remoteProof = initProtoBuffer(handshakeRes.remoteP2psecret) + remotePubKey: PublicKey + remotePubKeyBytes: seq[byte] + remoteSig: Signature + remoteSigBytes: seq[byte] - if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: - raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: - raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: + raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: + raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if not remotePubKey.init(remotePubKeyBytes): - raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if not remoteSig.init(remoteSigBytes): - raise newException(NoiseHandshakeError, "Failed to decode remote signature. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if not remotePubKey.init(remotePubKeyBytes): + raise newException(NoiseHandshakeError, "Failed to decode remote public key. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") + if not remoteSig.init(remoteSigBytes): + raise newException(NoiseHandshakeError, "Failed to decode remote signature. 
(initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes - if not remoteSig.verify(verifyPayload, remotePubKey): - raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") - else: - trace "Remote signature verified", peer = $conn + let verifyPayload = PayloadString.toBytes & handshakeRes.rs.getBytes + if not remoteSig.verify(verifyPayload, remotePubKey): + raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.") + else: + trace "Remote signature verified", peer = $conn - if initiator and not isNil(conn.peerInfo): - let pid = PeerID.init(remotePubKey) - if not conn.peerInfo.peerId.validate(): - raise newException(NoiseHandshakeError, "Failed to validate peerId.") - if pid.isErr or pid.get() != conn.peerInfo.peerId: - var - failedKey: PublicKey - discard extractPublicKey(conn.peerInfo.peerId, failedKey) - debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey - raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) + if initiator and not isNil(conn.peerInfo): + let pid = PeerID.init(remotePubKey) + if not conn.peerInfo.peerId.validate(): + raise newException(NoiseHandshakeError, "Failed to validate peerId.") + if pid.isErr or pid.get() != conn.peerInfo.peerId: + var + failedKey: PublicKey + discard extractPublicKey(conn.peerInfo.peerId, failedKey) + debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey + raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! 
" & $pid & " != " & $conn.peerInfo.peerId) - var secure = NoiseConnection.init(conn, - PeerInfo.init(remotePubKey), - conn.observedAddr) - if initiator: - secure.readCs = handshakeRes.cs2 - secure.writeCs = handshakeRes.cs1 - else: - secure.readCs = handshakeRes.cs1 - secure.writeCs = handshakeRes.cs2 + var tmp = NoiseConnection.init( + conn, PeerInfo.init(remotePubKey), conn.observedAddr) + + if initiator: + tmp.readCs = handshakeRes.cs2 + tmp.writeCs = handshakeRes.cs1 + else: + tmp.readCs = handshakeRes.cs1 + tmp.writeCs = handshakeRes.cs2 + tmp + finally: + burnMem(handshakeRes) trace "Noise handshake completed!", initiator, peer = $secure.peerInfo return secure +method close*(s: NoiseConnection) {.async.} = + await procCall SecureConn(s).close() + + burnMem(s.readCs) + burnMem(s.writeCs) + method init*(p: Noise) {.gcsafe.} = procCall Secure(p).init() p.codec = NoiseCodec @@ -491,7 +510,7 @@ proc newNoise*( rng: rng, outgoing: outgoing, localPrivateKey: privateKey, - localPublicKey: privateKey.getKey().tryGet(), + localPublicKey: privateKey.getKey().tryGet().getBytes().tryGet(), noiseKeys: genKeyPair(rng[]), commonPrologue: commonPrologue, ) diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 235de6e33..7f05fe49e 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -59,10 +59,9 @@ proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) - - conn.closeEvent.wait() - .addCallback do(udata: pointer = nil): - if not(isNil(sconn)): + if not isNil(sconn): + conn.closeEvent.wait() + .addCallback do(udata: pointer = nil): asyncCheck sconn.close() return sconn diff --git a/libp2p/stream/streamseq.nim b/libp2p/stream/streamseq.nim index bbb44aeba..30f62ad7f 100644 --- a/libp2p/stream/streamseq.nim +++ b/libp2p/stream/streamseq.nim @@ -61,6 +61,11 @@ template data*(v: StreamSeq): openArray[byte] = # TODO a double-hash comment here breaks compile (!) v.buf.toOpenArray(v.rpos, v.wpos - 1) +template toOpenArray*(v: StreamSeq, b, e: int): openArray[byte] = + # Data that is ready to be consumed + # TODO a double-hash comment here breaks compile (!) 
+ v.buf.toOpenArray(v.rpos + b, v.rpos + e - b) + func consume*(v: var StreamSeq, n: int) = ## Mark `n` bytes that were returned via `data` as consumed v.rpos += n @@ -71,3 +76,10 @@ func consumeTo*(v: var StreamSeq, buf: var openArray[byte]): int = copyMem(addr buf[0], addr v.buf[v.rpos], bytes) v.consume(bytes) bytes + +func clear*(v: var StreamSeq) = + v.consume(v.len) + +func assign*(v: var StreamSeq, buf: openArray[byte]) = + v.clear() + v.add(buf) From 9a3684c22159ef78c7ba746ce38204e0d4589f43 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 9 Jul 2020 10:59:09 +0200 Subject: [PATCH 08/23] init from concrete key type (#252) --- libp2p/crypto/crypto.nim | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 66fe81c9e..78b7dd519 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -374,6 +374,24 @@ proc init*(t: typedesc[PrivateKey], data: string): CryptoResult[PrivateKey] = except ValueError: err(KeyError) +proc init*(t: typedesc[PrivateKey], key: rsa.RsaPrivateKey): PrivateKey = + PrivateKey(scheme: RSA, rsakey: key) +proc init*(t: typedesc[PrivateKey], key: EdPrivateKey): PrivateKey = + PrivateKey(scheme: Ed25519, edkey: key) +proc init*(t: typedesc[PrivateKey], key: SkPrivateKey): PrivateKey = + PrivateKey(scheme: Secp256k1, skkey: key) +proc init*(t: typedesc[PrivateKey], key: ecnist.EcPrivateKey): PrivateKey = + PrivateKey(scheme: ECDSA, eckey: key) + +proc init*(t: typedesc[PublicKey], key: rsa.RsaPublicKey): PublicKey = + PublicKey(scheme: RSA, rsakey: key) +proc init*(t: typedesc[PublicKey], key: EdPublicKey): PublicKey = + PublicKey(scheme: Ed25519, edkey: key) +proc init*(t: typedesc[PublicKey], key: SkPublicKey): PublicKey = + PublicKey(scheme: Secp256k1, skkey: key) +proc init*(t: typedesc[PublicKey], key: ecnist.EcPublicKey): PublicKey = + PublicKey(scheme: ECDSA, eckey: key) + proc init*(t: typedesc[PublicKey], data: string): CryptoResult[PublicKey] = ## Create new public key from libp2p's protobuf serialized hexadecimal string ## form. 
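The init overloads added in the patch above wrap an already-constructed, scheme-specific key directly into the tagged PrivateKey/PublicKey objects, avoiding a protobuf serialize/deserialize round trip. Below is a minimal usage sketch, not part of the patch series: `loadEdKey` is a hypothetical placeholder for however an application obtains a concrete ed25519 key, and the import paths are assumed from nim-libp2p's module layout at the time.

  # Sketch of using the new concrete-key init overloads (assumptions noted above).
  import stew/results
  import libp2p/crypto/crypto
  import libp2p/crypto/ed25519/ed25519

  proc loadEdKey(): EdPrivateKey =
    # hypothetical placeholder: real code would load key material from a keystore
    discard

  let
    edKey = loadEdKey()
    sk = PrivateKey.init(edKey)   # wraps the concrete key, tagging scheme = Ed25519
    pk = sk.getKey().tryGet()     # derives the matching scheme-tagged PublicKey

  doAssert sk.scheme == Ed25519 and pk.scheme == Ed25519

The point of the overloads is that callers holding a concrete key type no longer need to go through the byte/protobuf-based init(data) constructors just to obtain the generic, scheme-tagged wrapper.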
From c720e042fc2b576bee0070bb94eed0f6a573931e Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Thu, 9 Jul 2020 19:16:46 +0200 Subject: [PATCH 09/23] clean up mesh handling logic (#260) * gossipsub is a function of subscription messages only * graft/prune work with mesh, get filled up from gossipsub * fix race conditions with await * fix exception unsafety when grafting/pruning * fix allowing up to DHi peers in mesh on incoming graft * fix metrics in several places --- libp2p/protocols/pubsub/gossipsub.nim | 279 +++++++++++++------------ libp2p/protocols/pubsub/pubsubpeer.nim | 22 +- tests/pubsub/testgossipinternal.nim | 4 +- 3 files changed, 163 insertions(+), 142 deletions(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 0fa25d25e..9f56ee03a 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -45,9 +45,9 @@ const GossipSubFanoutTTL* = 1.minutes type GossipSub* = ref object of FloodSub - mesh*: Table[string, HashSet[string]] # meshes - topic to peer - fanout*: Table[string, HashSet[string]] # fanout - topic to peer - gossipsub*: Table[string, HashSet[string]] # topic to peer map of all gossipsub peers + mesh*: Table[string, HashSet[string]] # peers that we send messages to when we are subscribed to the topic + fanout*: Table[string, HashSet[string]] # peers that we send messages to when we're not subscribed to the topic + gossipsub*: Table[string, HashSet[string]] # peers that are subscribed to a topic lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics gossip*: Table[string, seq[ControlIHave]] # pending gossip control*: Table[string, ControlMessage] # pending control messages @@ -68,6 +68,31 @@ declareGauge(libp2p_gossipsub_peers_per_topic_gossipsub, "gossipsub peers per topic in gossipsub", labels = ["topic"]) +func addPeer( + table: var Table[string, HashSet[string]], topic: string, + peerId: string): bool = + # returns true if the peer was added, false if it was already in the collection + not table.mgetOrPut(topic, initHashSet[string]()).containsOrIncl(peerId) + +func removePeer( + table: var Table[string, HashSet[string]], topic, peerId: string) = + table.withValue(topic, peers): + peers[].excl(peerId) + if peers[].len == 0: + table.del(topic) + +func hasPeer(table: Table[string, HashSet[string]], topic, peerId: string): bool = + (topic in table) and (peerId in table[topic]) + +func peers(table: Table[string, HashSet[string]], topic: string): int = + if topic in table: + table[topic].len + else: + 0 + +func getPeers(table: Table[string, HashSet[string]], topic: string): HashSet[string] = + table.getOrDefault(topic, initHashSet[string]()) + method init*(g: GossipSub) = proc handler(conn: Connection, proto: string) {.async.} = ## main protocol handler that gets triggered on every @@ -83,119 +108,102 @@ method init*(g: GossipSub) = proc replenishFanout(g: GossipSub, topic: string) = ## get fanout peers for a topic trace "about to replenish fanout" - if topic notin g.fanout: - g.fanout[topic] = initHashSet[string]() - if g.fanout.getOrDefault(topic).len < GossipSubDLo: - trace "replenishing fanout", peers = g.fanout.getOrDefault(topic).len - if topic in toSeq(g.gossipsub.keys): - for p in g.gossipsub.getOrDefault(topic): - if not g.fanout[topic].containsOrIncl(p): - if g.fanout.getOrDefault(topic).len == GossipSubD: + if g.fanout.peers(topic) < GossipSubDLo: + trace "replenishing fanout", peers = g.fanout.peers(topic) + if topic in g.gossipsub: + for peerId in 
g.gossipsub[topic]: + if g.fanout.addPeer(topic, peerId): + if g.fanout.peers(topic) == GossipSubD: break libp2p_gossipsub_peers_per_topic_fanout - .set(g.fanout.getOrDefault(topic).len.int64, - labelValues = [topic]) + .set(g.fanout.peers(topic).int64, labelValues = [topic]) - trace "fanout replenished with peers", peers = g.fanout.getOrDefault(topic).len - -template moveToMeshHelper(g: GossipSub, - topic: string, - table: Table[string, HashSet[string]]) = - ## move peers from `table` into `mesh` - ## - var peerIds = toSeq(table.getOrDefault(topic)) - - logScope: - topic = topic - meshPeers = g.mesh.getOrDefault(topic).len - peers = peerIds.len - - shuffle(peerIds) - for id in peerIds: - if g.mesh.getOrDefault(topic).len > GossipSubD: - break - - trace "gathering peers for mesh" - if topic notin table: - continue - - trace "getting peers", topic, - peers = peerIds.len - - table[topic].excl(id) # always exclude - if id in g.mesh[topic]: - continue # we already have this peer in the mesh, try again - - if id in g.peers: - let p = g.peers[id] - if p.connected: - # send a graft message to the peer - await p.sendGraft(@[topic]) - g.mesh[topic].incl(id) - trace "got peer", peer = id + trace "fanout replenished with peers", peers = g.fanout.peers(topic) proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = - try: - trace "about to rebalance mesh" - # create a mesh topic that we're subscribing to - if topic notin g.mesh: - g.mesh[topic] = initHashSet[string]() + trace "about to rebalance mesh" + # create a mesh topic that we're subscribing to - if g.mesh.getOrDefault(topic).len < GossipSubDlo: - trace "replenishing mesh", topic - # replenish the mesh if we're below GossipSubDlo + var + grafts, prunes: seq[PubSubPeer] - # move fanout nodes first - g.moveToMeshHelper(topic, g.fanout) + if g.mesh.peers(topic) < GossipSubDlo: + trace "replenishing mesh", topic, peers = g.mesh.peers(topic) + # replenish the mesh if we're below GossipSubDlo + var newPeers = toSeq( + g.gossipsub.getOrDefault(topic, initHashSet[string]()) - + g.mesh.getOrDefault(topic, initHashSet[string]()) + ) - # move gossipsub nodes second - g.moveToMeshHelper(topic, g.gossipsub) + logScope: + topic = topic + meshPeers = g.mesh.peers(topic) + newPeers = newPeers.len - if g.mesh.getOrDefault(topic).len > GossipSubDhi: - # prune peers if we've gone over - var mesh = toSeq(g.mesh.getOrDefault(topic)) - shuffle(mesh) + shuffle(newPeers) - trace "about to prune mesh", mesh = mesh.len - for id in mesh: - if g.mesh.getOrDefault(topic).len <= GossipSubD: - break + trace "getting peers", topic, peers = peerIds.len - trace "pruning peers", peers = g.mesh[topic].len - let p = g.peers[id] + for id in newPeers: + if g.mesh.peers(topic) >= GossipSubD: + break + + let p = g.peers.getOrDefault(id) + if p != nil: # send a graft message to the peer - await p.sendPrune(@[topic]) - g.mesh[topic].excl(id) - if topic in g.gossipsub: - g.gossipsub[topic].incl(id) + grafts.add p + discard g.mesh.addPeer(topic, id) + trace "got peer", peer = id + else: + # Peer should have been removed from mesh also! 
+ warn "Unknown peer in mesh", peer = id - libp2p_gossipsub_peers_per_topic_gossipsub - .set(g.gossipsub.getOrDefault(topic).len.int64, - labelValues = [topic]) + if g.mesh.peers(topic) > GossipSubDhi: + # prune peers if we've gone over + var mesh = toSeq(g.mesh[topic]) + shuffle(mesh) - libp2p_gossipsub_peers_per_topic_fanout - .set(g.fanout.getOrDefault(topic).len.int64, - labelValues = [topic]) + trace "about to prune mesh", mesh = mesh.len + for id in mesh: + if g.mesh.peers(topic) <= GossipSubD: + break - libp2p_gossipsub_peers_per_topic_mesh - .set(g.mesh.getOrDefault(topic).len.int64, - labelValues = [topic]) + trace "pruning peers", peers = g.mesh.peers(topic) + # send a graft message to the peer + g.mesh.removePeer(topic, id) - trace "mesh balanced, got peers", peers = g.mesh.getOrDefault(topic).len, - topicId = topic - except CancelledError as exc: - raise exc - except CatchableError as exc: - warn "exception occurred re-balancing mesh", exc = exc.msg + let p = g.peers.getOrDefault(id) + if p != nil: + prunes.add(p) -proc dropFanoutPeers(g: GossipSub) {.async.} = + libp2p_gossipsub_peers_per_topic_gossipsub + .set(g.gossipsub.peers(topic).int64, labelValues = [topic]) + + libp2p_gossipsub_peers_per_topic_fanout + .set(g.fanout.peers(topic).int64, labelValues = [topic]) + + libp2p_gossipsub_peers_per_topic_mesh + .set(g.mesh.peers(topic).int64, labelValues = [topic]) + + # Send changes to peers after table updates to avoid stale state + for p in grafts: + await p.sendGraft(@[topic]) + for p in prunes: + await p.sendPrune(@[topic]) + + trace "mesh balanced, got peers", peers = g.mesh.peers(topic), + topicId = topic + +proc dropFanoutPeers(g: GossipSub) = # drop peers that we haven't published to in # GossipSubFanoutTTL seconds var dropping = newSeq[string]() + let now = Moment.now() + for topic, val in g.lastFanoutPubSub: - if Moment.now > val: + if now > val: dropping.add(topic) g.fanout.del(topic) trace "dropping fanout topic", topic @@ -204,7 +212,7 @@ proc dropFanoutPeers(g: GossipSub) {.async.} = g.lastFanoutPubSub.del(topic) libp2p_gossipsub_peers_per_topic_fanout - .set(g.fanout.getOrDefault(topic).len.int64, labelValues = [topic]) + .set(g.fanout.peers(topic).int64, labelValues = [topic]) proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = ## gossip iHave messages to peers @@ -257,7 +265,7 @@ proc heartbeat(g: GossipSub) {.async.} = for t in toSeq(g.topics.keys): await g.rebalanceMesh(t) - await g.dropFanoutPeers() + g.dropFanoutPeers() # replenish known topics to the fanout for t in toSeq(g.fanout.keys): @@ -281,27 +289,23 @@ proc heartbeat(g: GossipSub) {.async.} = method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = ## handle peer disconnects procCall FloodSub(g).handleDisconnect(peer) - for t in toSeq(g.gossipsub.keys): - if t in g.gossipsub: - g.gossipsub[t].excl(peer.id) + g.gossipsub.removePeer(t, peer.id) libp2p_gossipsub_peers_per_topic_gossipsub - .set(g.gossipsub.getOrDefault(t).len.int64, labelValues = [t]) + .set(g.gossipsub.peers(t).int64, labelValues = [t]) for t in toSeq(g.mesh.keys): - if t in g.mesh: - g.mesh[t].excl(peer.id) + g.mesh.removePeer(t, peer.id) libp2p_gossipsub_peers_per_topic_mesh - .set(g.mesh[t].len.int64, labelValues = [t]) + .set(g.mesh.peers(t).int64, labelValues = [t]) for t in toSeq(g.fanout.keys): - if t in g.fanout: - g.fanout[t].excl(peer.id) + g.fanout.removePeer(t, peer.id) libp2p_gossipsub_peers_per_topic_fanout - .set(g.fanout[t].len.int64, labelValues = [t]) + .set(g.fanout.peers(t).int64, 
labelValues = [t]) method subscribePeer*(p: GossipSub, conn: Connection) = @@ -314,26 +318,26 @@ method subscribeTopic*(g: GossipSub, peerId: string) {.gcsafe, async.} = await procCall PubSub(g).subscribeTopic(topic, subscribe, peerId) - if topic notin g.gossipsub: - g.gossipsub[topic] = initHashSet[string]() - if subscribe: trace "adding subscription for topic", peer = peerId, name = topic # subscribe remote peer to the topic - g.gossipsub[topic].incl(peerId) + discard g.gossipsub.addPeer(topic, peerId) else: trace "removing subscription for topic", peer = peerId, name = topic # unsubscribe remote peer from the topic - g.gossipsub[topic].excl(peerId) - if peerId in g.mesh.getOrDefault(topic): - g.mesh[topic].excl(peerId) - if peerId in g.fanout.getOrDefault(topic): - g.fanout[topic].excl(peerId) + g.gossipsub.removePeer(topic, peerId) + g.mesh.removePeer(topic, peerId) + g.fanout.removePeer(topic, peerId) + + libp2p_gossipsub_peers_per_topic_mesh + .set(g.mesh.peers(topic).int64, labelValues = [topic]) + libp2p_gossipsub_peers_per_topic_fanout + .set(g.fanout.peers(topic).int64, labelValues = [topic]) libp2p_gossipsub_peers_per_topic_gossipsub - .set(g.gossipsub[topic].len.int64, labelValues = [topic]) + .set(g.gossipsub.peers(topic).int64, labelValues = [topic]) - trace "gossip peers", peers = g.gossipsub[topic].len, topic + trace "gossip peers", peers = g.gossipsub.peers(topic), topic # also rebalance current topic if we are subbed to if topic in g.topics: @@ -343,34 +347,39 @@ proc handleGraft(g: GossipSub, peer: PubSubPeer, grafts: seq[ControlGraft], respControl: var ControlMessage) = + let peerId = peer.id for graft in grafts: - trace "processing graft message", peer = peer.id, - topicID = graft.topicID + let topic = graft.topicID + trace "processing graft message", topic, peerId - if graft.topicID in g.topics: - if g.mesh.len < GossipSubD: - g.mesh[graft.topicID].incl(peer.id) + # If they send us a graft before they send us a subscribe, what should + # we do? For now, we add them to mesh but don't add them to gossipsub. 
+ + if topic in g.topics: + if g.mesh.peers(topic) < GossipSubDHi: + # In the spec, there's no mention of DHi here, but implicitly, a + # peer will be removed from the mesh on next rebalance, so we don't want + # this peer to push someone else out + if g.mesh.addPeer(topic, peerId): + g.fanout.removePeer(topic, peer.id) + else: + trace "Peer already in mesh", topic, peerId else: - g.gossipsub[graft.topicID].incl(peer.id) + respControl.prune.add(ControlPrune(topicID: topic)) else: - respControl.prune.add(ControlPrune(topicID: graft.topicID)) + respControl.prune.add(ControlPrune(topicID: topic)) libp2p_gossipsub_peers_per_topic_mesh - .set(g.mesh[graft.topicID].len.int64, labelValues = [graft.topicID]) - - libp2p_gossipsub_peers_per_topic_gossipsub - .set(g.gossipsub[graft.topicID].len.int64, labelValues = [graft.topicID]) + .set(g.mesh.peers(topic).int64, labelValues = [topic]) proc handlePrune(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) = for prune in prunes: trace "processing prune message", peer = peer.id, topicID = prune.topicID - if prune.topicID in g.mesh: - g.mesh[prune.topicID].excl(peer.id) - g.gossipsub[prune.topicID].incl(peer.id) - libp2p_gossipsub_peers_per_topic_mesh - .set(g.mesh[prune.topicID].len.int64, labelValues = [prune.topicID]) + g.mesh.removePeer(prune.topicID, peer.id) + libp2p_gossipsub_peers_per_topic_mesh + .set(g.mesh.peers(prune.topicID).int64, labelValues = [prune.topicID]) proc handleIHave(g: GossipSub, peer: PubSubPeer, @@ -485,9 +494,11 @@ method unsubscribe*(g: GossipSub, if topic in g.mesh: let peers = g.mesh.getOrDefault(topic) g.mesh.del(topic) + for id in peers: - let p = g.peers[id] - await p.sendPrune(@[topic]) + let p = g.peers.getOrDefault(id) + if p != nil: + await p.sendPrune(@[topic]) method publish*(g: GossipSub, topic: string, diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 1bda8adde..212404dab 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -163,14 +163,24 @@ proc sendMsg*(p: PubSubPeer, p.send(@[RPCMsg(messages: @[Message.init(p.peerInfo, data, topic, sign)])]) proc sendGraft*(p: PubSubPeer, topics: seq[string]) {.async.} = - for topic in topics: - trace "sending graft msg to peer", peer = p.id, topicID = topic - await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))]) + try: + for topic in topics: + trace "sending graft msg to peer", peer = p.id, topicID = topic + await p.send(@[RPCMsg(control: some(ControlMessage(graft: @[ControlGraft(topicID: topic)])))]) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "Could not send graft", msg = exc.msg proc sendPrune*(p: PubSubPeer, topics: seq[string]) {.async.} = - for topic in topics: - trace "sending prune msg to peer", peer = p.id, topicID = topic - await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))]) + try: + for topic in topics: + trace "sending prune msg to peer", peer = p.id, topicID = topic + await p.send(@[RPCMsg(control: some(ControlMessage(prune: @[ControlPrune(topicID: topic)])))]) + except CancelledError as exc: + raise exc + except CatchableError as exc: + trace "Could not send prune", msg = exc.msg proc `$`*(p: PubSubPeer): string = p.id diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index a6aa40e1e..c555d703c 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -137,7 +137,7 
@@ suite "GossipSub internal": check gossipSub.fanout[topic].len == GossipSubD - await gossipSub.dropFanoutPeers() + gossipSub.dropFanoutPeers() check topic notin gossipSub.fanout await allFuturesThrowing(conns.mapIt(it.close())) @@ -176,7 +176,7 @@ suite "GossipSub internal": check gossipSub.fanout[topic1].len == GossipSubD check gossipSub.fanout[topic2].len == GossipSubD - await gossipSub.dropFanoutPeers() + gossipSub.dropFanoutPeers() check topic1 notin gossipSub.fanout check topic2 in gossipSub.fanout From 4c815d75e7af62337f93bcbdd5ea5367c6f57207 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 9 Jul 2020 14:21:47 -0600 Subject: [PATCH 10/23] More gossip cleanup (#257) * more cleanup * correct pubsub peer count * close the stream first * handle cancelation * fix tests * fix fanout ttl * merging master * remove `withLock` as it conflicts with stdlib * fix trace build Co-authored-by: Giovanni Petrantoni --- libp2p/protocols/pubsub/gossipsub.nim | 24 +++++++-------- libp2p/protocols/pubsub/pubsub.nim | 42 +++++++++++++++----------- libp2p/protocols/pubsub/pubsubpeer.nim | 1 - libp2p/protocols/secure/secure.nim | 4 +-- libp2p/stream/chronosstream.nim | 5 --- libp2p/switch.nim | 3 +- tests/pubsub/testgossipinternal.nim | 3 +- tests/pubsub/testgossipsub.nim | 2 -- 8 files changed, 42 insertions(+), 42 deletions(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 9f56ee03a..c1864089f 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -38,7 +38,7 @@ const GossipSubHistoryGossip* = 3 # heartbeat interval const GossipSubHeartbeatInitialDelay* = 100.millis -const GossipSubHeartbeatInterval* = 1.seconds +const GossipSubHeartbeatInterval* = 5.seconds # TODO: per the spec it should be 1 second # fanout ttl const GossipSubFanoutTTL* = 1.minutes @@ -144,7 +144,7 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = shuffle(newPeers) - trace "getting peers", topic, peers = peerIds.len + trace "getting peers", topic, peers = newPeers.len for id in newPeers: if g.mesh.peers(topic) >= GossipSubD: @@ -208,9 +208,6 @@ proc dropFanoutPeers(g: GossipSub) = g.fanout.del(topic) trace "dropping fanout topic", topic - for topic in dropping: - g.lastFanoutPubSub.del(topic) - libp2p_gossipsub_peers_per_topic_fanout .set(g.fanout.peers(topic).int64, labelValues = [topic]) @@ -245,10 +242,6 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = trace "got gossip peers", peers = result.len break - if allPeers.len == 0: - trace "no peers for topic, skipping", topicID = topic - break - if id in gossipPeers: continue @@ -284,7 +277,7 @@ proc heartbeat(g: GossipSub) {.async.} = except CatchableError as exc: trace "exception ocurred in gossipsub heartbeat", exc = exc.msg - await sleepAsync(5.seconds) + await sleepAsync(GossipSubHeartbeatInterval) method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = ## handle peer disconnects @@ -308,7 +301,7 @@ method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = .set(g.fanout.peers(t).int64, labelValues = [t]) method subscribePeer*(p: GossipSub, - conn: Connection) = + conn: Connection) = procCall PubSub(p).subscribePeer(conn) asyncCheck p.handleConn(conn, GossipSubCodec) @@ -316,7 +309,7 @@ method subscribeTopic*(g: GossipSub, topic: string, subscribe: bool, peerId: string) {.gcsafe, async.} = - await procCall PubSub(g).subscribeTopic(topic, subscribe, peerId) + await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId) if 
subscribe: trace "adding subscription for topic", peer = peerId, name = topic @@ -521,6 +514,13 @@ method publish*(g: GossipSub, g.replenishFanout(topic) peers = g.fanout.getOrDefault(topic) + # even if we couldn't publish, + # we still attempted to publish + # on the topic, so it makes sense + # to update the last topic publish + # time + g.lastFanoutPubSub[topic] = Moment.fromNow(GossipSubFanoutTTL) + let msg = Message.init(g.peerInfo, data, topic, g.sign) msgId = g.msgIdProvider(msg) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index dbfe012ba..1a97ffab1 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -57,7 +57,7 @@ type cleanupLock: AsyncLock validators*: Table[string, HashSet[ValidatorHandler]] observers: ref seq[PubSubObserver] # ref as in smart_ptr - msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) + msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = ## handle peer disconnects @@ -65,10 +65,10 @@ method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = if not isNil(peer.peerInfo) and peer.id in p.peers: trace "deleting peer", peer = peer.id p.peers.del(peer.id) + trace "peer disconnected", peer = peer.id # metrics libp2p_pubsub_peers.set(p.peers.len.int64) - trace "peer disconnected", peer = peer.id proc sendSubs*(p: PubSub, peer: PubSubPeer, @@ -120,9 +120,9 @@ method rpcHandler*(p: PubSub, trace "about to subscribe to topic", topicId = s.topic await p.subscribeTopic(s.topic, s.subscribe, peer.id) -proc getPeer(p: PubSub, - peerInfo: PeerInfo, - proto: string): PubSubPeer = +proc getOrCreatePeer(p: PubSub, + peerInfo: PeerInfo, + proto: string): PubSubPeer = if peerInfo.id in p.peers: return p.peers[peerInfo.id] @@ -132,7 +132,6 @@ proc getPeer(p: PubSub, p.peers[peer.id] = peer peer.observers = p.observers - libp2p_pubsub_peers.set(p.peers.len.int64) return peer method handleConn*(p: PubSub, @@ -158,7 +157,7 @@ method handleConn*(p: PubSub, # call pubsub rpc handler await p.rpcHandler(peer, msgs) - let peer = p.getPeer(conn.peerInfo, proto) + let peer = p.getOrCreatePeer(conn.peerInfo, proto) let topics = toSeq(p.topics.keys) if topics.len > 0: await p.sendSubs(peer, topics, true) @@ -177,23 +176,27 @@ method handleConn*(p: PubSub, method subscribePeer*(p: PubSub, conn: Connection) {.base.} = if not(isNil(conn)): - let peer = p.getPeer(conn.peerInfo, p.codec) + let peer = p.getOrCreatePeer(conn.peerInfo, p.codec) trace "subscribing to peer", peerId = conn.peerInfo.id if not peer.connected: peer.conn = conn method unsubscribePeer*(p: PubSub, peerInfo: PeerInfo) {.base, async.} = - let peer = p.getPeer(peerInfo, p.codec) - trace "unsubscribing from peer", peerId = $peerInfo - if not(isNil(peer.conn)): - await peer.conn.close() + if peerInfo.id in p.peers: + let peer = p.peers[peerInfo.id] - p.handleDisconnect(peer) + trace "unsubscribing from peer", peerId = $peerInfo + if not(isNil(peer.conn)): + await peer.conn.close() -proc connected*(p: PubSub, peer: PeerInfo): bool = - let peer = p.getPeer(peer, p.codec) - if not(isNil(peer)): - return peer.connected + p.handleDisconnect(peer) + +proc connected*(p: PubSub, peerInfo: PeerInfo): bool = + if peerInfo.id in p.peers: + let peer = p.peers[peerInfo.id] + + if not(isNil(peer)): + return peer.connected method unsubscribe*(p: PubSub, topics: seq[TopicPair]) {.base, async.} = @@ -205,6 +208,11 @@ method unsubscribe*(p: PubSub, if h == 
t.handler: p.topics[t.topic].handler.del(i) + # make sure we delete the topic if + # no more handlers are left + if p.topics[t.topic].handler.len <= 0: + p.topics.del(t.topic) + method unsubscribe*(p: PubSub, topic: string, handler: TopicHandler): Future[void] {.base.} = diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 212404dab..db1312e83 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -36,7 +36,6 @@ type sendConn: Connection peerInfo*: PeerInfo handler*: RPCHandler - topics*: seq[string] sentRpcCache: TimedCache[string] # cache for already sent messages recvdRpcCache: TimedCache[string] # cache for already received messages onConnect*: AsyncEvent diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index 7f05fe49e..aeff8fa50 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -42,11 +42,11 @@ method initStream*(s: SecureConn) = procCall Connection(s).initStream() method close*(s: SecureConn) {.async.} = - await procCall Connection(s).close() - if not(isNil(s.stream)): await s.stream.close() + await procCall Connection(s).close() + method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = doAssert(false, "Not implemented!") diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 312ae237e..4c27ff2e5 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -75,15 +75,10 @@ method close*(s: ChronosStream) {.async.} = if not s.isClosed: trace "shutting down chronos stream", address = $s.client.remoteAddress(), oid = s.oid - - # TODO: the sequence here matters - # don't move it after the connections - # close bellow if not s.client.closed(): await s.client.closeWait() await procCall Connection(s).close() - except CancelledError as exc: raise exc except CatchableError as exc: diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 418b0ca79..015eb32a7 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -302,12 +302,11 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = conn.peerInfo.close() finally: await conn.close() + libp2p_peers.set(s.connections.len.int64) if lock.locked(): lock.release() - libp2p_peers.set(s.connections.len.int64) - proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = let connections = s.connections.getOrDefault(peer.id) for connHolder in connections: diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index c555d703c..2fb13d765 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -32,6 +32,7 @@ suite "GossipSub internal": gossipSub.mesh[topic] = initHashSet[string]() var conns = newSeq[Connection]() + gossipSub.gossipsub[topic] = initHashSet[string]() for i in 0..<15: let conn = newBufferStream(noop) conns &= conn @@ -60,6 +61,7 @@ suite "GossipSub internal": gossipSub.mesh[topic] = initHashSet[string]() gossipSub.topics[topic] = Topic() # has to be in topics to rebalance + gossipSub.gossipsub[topic] = initHashSet[string]() var conns = newSeq[Connection]() for i in 0..<15: let conn = newBufferStream(noop) @@ -99,7 +101,6 @@ suite "GossipSub internal": conn.peerInfo = peerInfo gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) gossipSub.peers[peerInfo.id].handler = handler - gossipSub.peers[peerInfo.id].topics &= topic gossipSub.gossipsub[topic].incl(peerInfo.id) check gossipSub.gossipsub[topic].len 
== 15 diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 572d4f617..757429500 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -86,7 +86,6 @@ suite "GossipSub": nodes[1].addValidator("foobar", validator) tryPublish await nodes[0].publish("foobar", "Hello!".toBytes()), 1 - result = (await validatorFut) and (await handlerFut) await allFuturesThrowing( nodes[0].stop(), @@ -142,7 +141,6 @@ suite "GossipSub": awaiters.add((await nodes[1].start())) await subscribeNodes(nodes) - await nodes[1].subscribe("foo", handler) await nodes[1].subscribe("bar", handler) From bec9a0658f612bcaec518021717462fd5f509ba8 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Thu, 9 Jul 2020 17:54:16 -0600 Subject: [PATCH 11/23] Cleanup rpc handler (#261) * more cleanup * fix tests * merging master * remove `withLock` as it conflicts with stdlib * wip * more fanout ttl Co-authored-by: Giovanni Petrantoni --- libp2p/protocols/pubsub/gossipsub.nim | 29 +++++++++---------- libp2p/protocols/pubsub/rpc/protobuf.nim | 36 +++++++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index c1864089f..32240bc75 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -199,13 +199,12 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = proc dropFanoutPeers(g: GossipSub) = # drop peers that we haven't published to in # GossipSubFanoutTTL seconds - var dropping = newSeq[string]() let now = Moment.now() - - for topic, val in g.lastFanoutPubSub: + for topic in toSeq(g.lastFanoutPubSub.keys): + let val = g.lastFanoutPubSub[topic] if now > val: - dropping.add(topic) g.fanout.del(topic) + g.lastFanoutPubSub.del(topic) trace "dropping fanout topic", topic libp2p_gossipsub_peers_per_topic_fanout @@ -338,8 +337,7 @@ method subscribeTopic*(g: GossipSub, proc handleGraft(g: GossipSub, peer: PubSubPeer, - grafts: seq[ControlGraft], - respControl: var ControlMessage) = + grafts: seq[ControlGraft]): seq[ControlPrune] = let peerId = peer.id for graft in grafts: let topic = graft.topicID @@ -358,9 +356,9 @@ proc handleGraft(g: GossipSub, else: trace "Peer already in mesh", topic, peerId else: - respControl.prune.add(ControlPrune(topicID: topic)) + result.add(ControlPrune(topicID: topic)) else: - respControl.prune.add(ControlPrune(topicID: topic)) + result.add(ControlPrune(topicID: topic)) libp2p_gossipsub_peers_per_topic_mesh .set(g.mesh.peers(topic).int64, labelValues = [topic]) @@ -459,18 +457,17 @@ method rpcHandler*(g: GossipSub, var respControl: ControlMessage if m.control.isSome: - var control: ControlMessage = m.control.get() - let iWant: ControlIWant = g.handleIHave(peer, control.ihave) - if iWant.messageIDs.len > 0: - respControl.iwant.add(iWant) - let messages: seq[Message] = g.handleIWant(peer, control.iwant) - - g.handleGraft(peer, control.graft, respControl) + let control = m.control.get() g.handlePrune(peer, control.prune) + respControl.iwant.add(g.handleIHave(peer, control.ihave)) + respControl.prune.add(g.handleGraft(peer, control.graft)) + if respControl.graft.len > 0 or respControl.prune.len > 0 or respControl.ihave.len > 0 or respControl.iwant.len > 0: - await peer.send(@[RPCMsg(control: some(respControl), messages: messages)]) + await peer.send( + @[RPCMsg(control: some(respControl), + messages: g.handleIWant(peer, control.iwant))]) method subscribe*(g: GossipSub, topic: string, diff --git 
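
# dropFanoutPeers above now snapshots the keys with toSeq before deleting,
# because removing entries while walking a Table's own keys iterator is not
# safe.  A self-contained sketch of that pattern; plain times.Time stands in
# for the chronos Moment deadlines used by the real code:

import tables, sequtils, times

proc dropExpired(deadlines: var Table[string, Time], now: Time): seq[string] =
  for topic in toSeq(deadlines.keys):   # snapshot first, then mutate freely
    if now > deadlines[topic]:
      deadlines.del(topic)
      result.add(topic)

when isMainModule:
  var deadlines = initTable[string, Time]()
  deadlines["foo"] = getTime() - initDuration(seconds = 1)   # already expired
  deadlines["bar"] = getTime() + initDuration(minutes = 1)
  doAssert dropExpired(deadlines, getTime()) == @["foo"]
  doAssert "bar" in deadlines
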
a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim index ce42275d0..922546b20 100644 --- a/libp2p/protocols/pubsub/rpc/protobuf.nim +++ b/libp2p/protocols/pubsub/rpc/protobuf.nim @@ -88,8 +88,9 @@ proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = h.encodeIHave(ihave) # write messages to protobuf - ihave.finish() - pb.write(initProtoField(1, ihave)) + if ihave.buffer.len > 0: + ihave.finish() + pb.write(initProtoField(1, ihave)) if control.iwant.len > 0: var iwant = initProtoBuffer() @@ -97,8 +98,9 @@ proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = w.encodeIWant(iwant) # write messages to protobuf - iwant.finish() - pb.write(initProtoField(2, iwant)) + if iwant.buffer.len > 0: + iwant.finish() + pb.write(initProtoField(2, iwant)) if control.graft.len > 0: var graft = initProtoBuffer() @@ -106,8 +108,9 @@ proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = g.encodeGraft(graft) # write messages to protobuf - graft.finish() - pb.write(initProtoField(3, graft)) + if graft.buffer.len > 0: + graft.finish() + pb.write(initProtoField(3, graft)) if control.prune.len > 0: var prune = initProtoBuffer() @@ -115,8 +118,9 @@ proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = p.encodePrune(prune) # write messages to protobuf - prune.finish() - pb.write(initProtoField(4, prune)) + if prune.buffer.len > 0: + prune.finish() + pb.write(initProtoField(4, prune)) proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} = trace "decoding control submessage" @@ -225,9 +229,11 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = for s in msg.subscriptions: var subs = initProtoBuffer() encodeSubs(s, subs) + # write subscriptions to protobuf - subs.finish() - result.write(initProtoField(1, subs)) + if subs.buffer.len > 0: + subs.finish() + result.write(initProtoField(1, subs)) if msg.messages.len > 0: var messages = initProtoBuffer() @@ -235,16 +241,18 @@ proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = encodeMessage(m, messages) # write messages to protobuf - messages.finish() - result.write(initProtoField(2, messages)) + if messages.buffer.len > 0: + messages.finish() + result.write(initProtoField(2, messages)) if msg.control.isSome: var control = initProtoBuffer() msg.control.get.encodeControl(control) # write messages to protobuf - control.finish() - result.write(initProtoField(3, control)) + if control.buffer.len > 0: + control.finish() + result.write(initProtoField(3, control)) if result.buffer.len > 0: result.finish() From 503a7ec1c5dd47f0bd31d15027e871064f610d4f Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Sun, 12 Jul 2020 11:14:49 +0900 Subject: [PATCH 12/23] disable arm64 builds for now (travis) --- .travis.yml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index c116e1767..faef01911 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,13 +22,14 @@ matrix: - NPROC=2 before_install: - export GOPATH=$HOME/go - - os: linux - arch: arm64 - env: - - NPROC=6 # Worth trying more than 2 parallel jobs: https://travis-ci.community/t/no-cache-support-on-arm64/5416/8 - # (also used to get a different cache key than the amd64 one) - before_install: - - export GOPATH=$HOME/go + ## arm64 is very unreliable and slow, disabled for now + # - os: linux + # arch: arm64 + # env: + # - NPROC=6 # Worth trying more than 2 parallel jobs: 
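
# The encodeControl/encodeRpcMsg guards above only emit a nested field when
# something was actually written into the sub-buffer, so empty
# length-delimited fields never reach the wire.  The same guard over raw bytes
# (0x0A is the header for field 1 with wire type 2, i.e. length-delimited;
# single-byte lengths are assumed to keep the sketch short):

proc writeLengthDelimited(dst: var seq[byte], fieldHeader: byte,
                          payload: seq[byte]) =
  if payload.len == 0:
    return                      # nothing to emit for an empty sub-message
  dst.add(fieldHeader)
  dst.add(byte(payload.len))    # assumes payloads shorter than 128 bytes
  dst.add(payload)

when isMainModule:
  var buf: seq[byte]
  buf.writeLengthDelimited(0x0A, @[])                 # skipped entirely
  buf.writeLengthDelimited(0x0A, @[byte 0x08, 0x01])
  doAssert buf == @[byte 0x0A, 0x02, 0x08, 0x01]
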
https://travis-ci.community/t/no-cache-support-on-arm64/5416/8 + # # (also used to get a different cache key than the amd64 one) + # before_install: + # - export GOPATH=$HOME/go - os: osx env: - NPROC=2 From 181cf73ca75132203f369cde2262b19f8c3edc8e Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Sun, 12 Jul 2020 10:37:10 -0600 Subject: [PATCH 13/23] Drain buffer (#264) * drain lpchannel on reset * move drainBuffer to bufferstream --- libp2p/muxers/mplex/lpchannel.nim | 15 ++++++++------- libp2p/stream/bufferstream.nim | 9 +++++++++ libp2p/stream/lpstream.nim | 18 ++++++++++++++---- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 771a4e3a3..41be7cbad 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -189,14 +189,11 @@ proc closeRemote*(s: LPChannel) {.async.} = # stack = getStackTrace() trace "got EOF, closing channel" - - # wait for all data in the buffer to be consumed - while s.len > 0: - await s.dataReadEvent.wait() - s.dataReadEvent.clear() + await s.drainBuffer() s.isEof = true # set EOF immediately to prevent further reads await s.close() # close local end + # call to avoid leaks await procCall BufferStream(s).close() # close parent bufferstream trace "channel closed on EOF" @@ -227,7 +224,11 @@ method reset*(s: LPChannel) {.base, async, gcsafe.} = # might be dead already - reset is always # optimistic asyncCheck s.resetMessage() + + # drain the buffer before closing + await s.drainBuffer() await procCall BufferStream(s).close() + s.isEof = true s.closedLocal = true @@ -254,11 +255,11 @@ method close*(s: LPChannel) {.async, gcsafe.} = if s.atEof: # already closed by remote close parent buffer immediately await procCall BufferStream(s).close() except CancelledError as exc: - await s.reset() # reset on timeout + await s.reset() raise exc except CatchableError as exc: trace "exception closing channel" - await s.reset() # reset on timeout + await s.reset() trace "lpchannel closed local" diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index c15fa7bf5..41c51f6ed 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -198,6 +198,15 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} = await s.dataReadEvent.wait() s.dataReadEvent.clear() +proc drainBuffer*(s: BufferStream) {.async.} = + ## wait for all data in the buffer to be consumed + ## + + trace "draining buffer", len = s.len + while s.len > 0: + await s.dataReadEvent.wait() + s.dataReadEvent.clear() + method readOnce*(s: BufferStream, pbytes: pointer, nbytes: int): diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index f7a4eda80..68bd47ef1 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -102,22 +102,32 @@ method readOnce*(s: LPStream, doAssert(false, "not implemented!") proc readExactly*(s: LPStream, - pbytes: pointer, - nbytes: int): - Future[void] {.async.} = + pbytes: pointer, + nbytes: int): + Future[void] {.async.} = if s.atEof: raise newLPStreamEOFError() + logScope: + nbytes = nbytes + obName = s.objName + stack = getStackTrace() + oid = $s.oid + var pbuffer = cast[ptr UncheckedArray[byte]](pbytes) var read = 0 while read < nbytes and not(s.atEof()): read += await s.readOnce(addr pbuffer[read], nbytes - read) if read < nbytes: + trace "incomplete data received", read raise newLPStreamIncompleteError() -proc readLine*(s: LPStream, limit = 0, sep = "\r\n"): Future[string] 
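
# readExactly above keeps calling readOnce until the requested count has
# arrived and only raises the "incomplete" error when the stream runs dry
# early.  A blocking, in-memory sketch of the same loop; ChunkSource and its
# readOnce are hypothetical stand-ins for the async LPStream API:

type ChunkSource = object
  data: seq[byte]
  pos: int

proc readOnce(s: var ChunkSource, max: int): seq[byte] =
  # pretend the transport hands out at most 3 bytes per call
  let n = min(3, min(max, s.data.len - s.pos))
  result = s.data[s.pos ..< s.pos + n]
  s.pos += n

proc readExactly(s: var ChunkSource, nbytes: int): seq[byte] =
  while result.len < nbytes:
    let chunk = s.readOnce(nbytes - result.len)
    if chunk.len == 0:
      raise newException(IOError, "stream ended with incomplete data")
    result.add(chunk)

when isMainModule:
  var src = ChunkSource(data: @[byte 1, 2, 3, 4, 5])
  doAssert readExactly(src, 4) == @[byte 1, 2, 3, 4]
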
{.async, deprecated: "todo".} = +proc readLine*(s: LPStream, + limit = 0, + sep = "\r\n"): Future[string] + {.async, deprecated: "todo".} = # TODO replace with something that exploits buffering better var lim = if limit <= 0: -1 else: limit var state = 0 From efb952f18b42a1ffbd4f622d24c405d16188ff97 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Mon, 13 Jul 2020 15:43:07 +0300 Subject: [PATCH 14/23] [WIP] Minprotobuf refactoring (#259) * Minprotobuf initial commit * Fix noise. * Add signed integers support. Add checks for field number value. Remove some casts. * Fix compile errors. * Fix comments and constants. --- libp2p/crypto/crypto.nim | 164 +++--- libp2p/multiaddress.nim | 88 ++- libp2p/peerid.nim | 23 +- libp2p/protobuf/minprotobuf.nim | 652 ++++++++++++++++++++-- libp2p/protocols/identify.nim | 46 +- libp2p/protocols/pubsub/rpc/message.nim | 9 +- libp2p/protocols/pubsub/rpc/protobuf.nim | 450 ++++++++------- libp2p/protocols/secure/noise.nim | 8 +- tests/testminprotobuf.nim | 677 +++++++++++++++++++++++ tests/testnative.nim | 1 + 10 files changed, 1705 insertions(+), 413 deletions(-) create mode 100644 tests/testminprotobuf.nim diff --git a/libp2p/crypto/crypto.nim b/libp2p/crypto/crypto.nim index 78b7dd519..893a2245f 100644 --- a/libp2p/crypto/crypto.nim +++ b/libp2p/crypto/crypto.nim @@ -222,8 +222,8 @@ proc toBytes*(key: PrivateKey, data: var openarray[byte]): CryptoResult[int] = ## ## Returns number of bytes (octets) needed to store private key ``key``. var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() var blen = len(msg.buffer) if len(data) >= blen: @@ -236,8 +236,8 @@ proc toBytes*(key: PublicKey, data: var openarray[byte]): CryptoResult[int] = ## ## Returns number of bytes (octets) needed to store public key ``key``. var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() var blen = len(msg.buffer) if len(data) >= blen and blen > 0: @@ -256,8 +256,8 @@ proc getBytes*(key: PrivateKey): CryptoResult[seq[byte]] = ## Return private key ``key`` in binary form (using libp2p's protobuf ## serialization). var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? key.getRawBytes()) msg.finish() ok(msg.buffer) @@ -265,8 +265,8 @@ proc getBytes*(key: PublicKey): CryptoResult[seq[byte]] = ## Return public key ``key`` in binary form (using libp2p's protobuf ## serialization). var msg = initProtoBuffer() - msg.write(initProtoField(1, cast[uint64](key.scheme))) - msg.write(initProtoField(2, ? key.getRawBytes())) + msg.write(1, uint64(key.scheme)) + msg.write(2, ? 
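
# The new write(pb, field, value) calls used above come down to two varints on
# the wire: a header encoding (field shl 3) or wireType, then the value.  A
# standalone sketch of that layout, independent of the library's ProtoBuffer:

proc putUVarint(dst: var seq[byte], x: uint64) =
  ## unsigned LEB128: 7 bits per byte, MSB set on all but the last byte
  var v = x
  while v >= 0x80'u64:
    dst.add(byte(v and 0x7F) or 0x80'u8)
    v = v shr 7
  dst.add(byte(v))

proc writeVarintField(dst: var seq[byte], field: int, value: uint64) =
  dst.putUVarint(uint64(field) shl 3)   # low 3 bits 000 = wire type Varint
  dst.putUVarint(value)

when isMainModule:
  var buf: seq[byte]
  buf.writeVarintField(1, 2)            # a small enum value stored as field 1
  doAssert buf == @[byte 0x08, 0x02]
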
key.getRawBytes()) msg.finish() ok(msg.buffer) @@ -283,33 +283,32 @@ proc init*[T: PrivateKey|PublicKey](key: var T, data: openarray[byte]): bool = var buffer: seq[byte] if len(data) > 0: var pb = initProtoBuffer(@data) - if pb.getVarintValue(1, id) != 0: - if pb.getBytes(2, buffer) != 0: - if cast[int8](id) in SupportedSchemesInt: - var scheme = cast[PKScheme](cast[int8](id)) - when key is PrivateKey: - var nkey = PrivateKey(scheme: scheme) - else: - var nkey = PublicKey(scheme: scheme) - case scheme: - of PKScheme.RSA: - if init(nkey.rsakey, buffer).isOk: - key = nkey - return true - of PKScheme.Ed25519: - if init(nkey.edkey, buffer): - key = nkey - return true - of PKScheme.ECDSA: - if init(nkey.eckey, buffer).isOk: - key = nkey - return true - of PKScheme.Secp256k1: - if init(nkey.skkey, buffer).isOk: - key = nkey - return true - else: - return false + if pb.getField(1, id) and pb.getField(2, buffer): + if cast[int8](id) in SupportedSchemesInt and len(buffer) > 0: + var scheme = cast[PKScheme](cast[int8](id)) + when key is PrivateKey: + var nkey = PrivateKey(scheme: scheme) + else: + var nkey = PublicKey(scheme: scheme) + case scheme: + of PKScheme.RSA: + if init(nkey.rsakey, buffer).isOk: + key = nkey + return true + of PKScheme.Ed25519: + if init(nkey.edkey, buffer): + key = nkey + return true + of PKScheme.ECDSA: + if init(nkey.eckey, buffer).isOk: + key = nkey + return true + of PKScheme.Secp256k1: + if init(nkey.skkey, buffer).isOk: + key = nkey + return true + else: + return false proc init*(sig: var Signature, data: openarray[byte]): bool = ## Initialize signature ``sig`` from raw binary form. @@ -727,11 +726,11 @@ proc createProposal*(nonce, pubkey: openarray[byte], ## ``exchanges``, comma-delimeted list of supported ciphers ``ciphers`` and ## comma-delimeted list of supported hashes ``hashes``. var msg = initProtoBuffer({WithUint32BeLength}) - msg.write(initProtoField(1, nonce)) - msg.write(initProtoField(2, pubkey)) - msg.write(initProtoField(3, exchanges)) - msg.write(initProtoField(4, ciphers)) - msg.write(initProtoField(5, hashes)) + msg.write(1, nonce) + msg.write(2, pubkey) + msg.write(3, exchanges) + msg.write(4, ciphers) + msg.write(5, hashes) msg.finish() shallowCopy(result, msg.buffer) @@ -744,19 +743,16 @@ proc decodeProposal*(message: seq[byte], nonce, pubkey: var seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. var pb = initProtoBuffer(message) - if pb.getLengthValue(1, nonce) != -1 and - pb.getLengthValue(2, pubkey) != -1 and - pb.getLengthValue(3, exchanges) != -1 and - pb.getLengthValue(4, ciphers) != -1 and - pb.getLengthValue(5, hashes) != -1: - result = true + pb.getField(1, nonce) and pb.getField(2, pubkey) and + pb.getField(3, exchanges) and pb.getField(4, ciphers) and + pb.getField(5, hashes) proc createExchange*(epubkey, signature: openarray[byte]): seq[byte] = ## Create SecIO exchange message using ephemeral public key ``epubkey`` and ## signature of proposal blocks ``signature``. var msg = initProtoBuffer({WithUint32BeLength}) - msg.write(initProtoField(1, epubkey)) - msg.write(initProtoField(2, signature)) + msg.write(1, epubkey) + msg.write(2, signature) msg.finish() shallowCopy(result, msg.buffer) @@ -767,9 +763,7 @@ proc decodeExchange*(message: seq[byte], ## ## Procedure returns ``true`` on success and ``false`` on error. 
var pb = initProtoBuffer(message) - if pb.getLengthValue(1, pubkey) != -1 and - pb.getLengthValue(2, signature) != -1: - result = true + pb.getField(1, pubkey) and pb.getField(2, signature) ## Serialization/Deserialization helpers @@ -788,22 +782,27 @@ proc write*(vb: var VBuffer, sig: PrivateKey) {. ## Write Signature value ``sig`` to buffer ``vb``. vb.writeSeq(sig.getBytes().tryGet()) -proc initProtoField*(index: int, pubkey: PublicKey): ProtoField {. - raises: [Defect, ResultError[CryptoError]].} = - ## Initialize ProtoField with PublicKey ``pubkey``. - result = initProtoField(index, pubkey.getBytes().tryGet()) +proc write*[T: PublicKey|PrivateKey](pb: var ProtoBuffer, field: int, + key: T) {. + inline, raises: [Defect, ResultError[CryptoError]].} = + write(pb, field, key.getBytes().tryGet()) -proc initProtoField*(index: int, seckey: PrivateKey): ProtoField {. - raises: [Defect, ResultError[CryptoError]].} = - ## Initialize ProtoField with PrivateKey ``seckey``. - result = initProtoField(index, seckey.getBytes().tryGet()) +proc write*(pb: var ProtoBuffer, field: int, sig: Signature) {. + inline, raises: [Defect, ResultError[CryptoError]].} = + write(pb, field, sig.getBytes()) -proc initProtoField*(index: int, sig: Signature): ProtoField = +proc initProtoField*(index: int, key: PublicKey|PrivateKey): ProtoField {. + deprecated, raises: [Defect, ResultError[CryptoError]].} = + ## Initialize ProtoField with PublicKey/PrivateKey ``key``. + result = initProtoField(index, key.getBytes().tryGet()) + +proc initProtoField*(index: int, sig: Signature): ProtoField {.deprecated.} = ## Initialize ProtoField with Signature ``sig``. result = initProtoField(index, sig.getBytes()) -proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int = - ## Read ``PublicKey`` from ProtoBuf's message and validate it. +proc getValue*[T: PublicKey|PrivateKey](data: var ProtoBuffer, field: int, + value: var T): int {.deprecated.} = + ## Read PublicKey/PrivateKey from ProtoBuf's message and validate it. var buf: seq[byte] var key: PublicKey result = getLengthValue(data, field, buf) @@ -813,18 +812,8 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PublicKey): int = else: value = key -proc getValue*(data: var ProtoBuffer, field: int, value: var PrivateKey): int = - ## Read ``PrivateKey`` from ProtoBuf's message and validate it. - var buf: seq[byte] - var key: PrivateKey - result = getLengthValue(data, field, buf) - if result > 0: - if not key.init(buf): - result = -1 - else: - value = key - -proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int = +proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int {. + deprecated.} = ## Read ``Signature`` from ProtoBuf's message and validate it. 
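
# decodeProposal/decodeExchange above now simply return the conjunction of the
# required getField calls: the message is only valid if every expected field
# decoded, and `and` short-circuits after the first failure.  The same shape
# over a toy field table (this getField is a stand-in, not the library's):

import tables

proc getField(fields: Table[int, seq[byte]], field: int,
              output: var seq[byte]): bool =
  if field notin fields:
    return false
  output = fields[field]
  true

proc decodeExchange(fields: Table[int, seq[byte]],
                    pubkey, signature: var seq[byte]): bool =
  fields.getField(1, pubkey) and fields.getField(2, signature)

when isMainModule:
  var pk, sig: seq[byte]
  doAssert not decodeExchange({1: @[byte 1]}.toTable, pk, sig)
  doAssert decodeExchange({1: @[byte 1], 2: @[byte 2]}.toTable, pk, sig)
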
var buf: seq[byte] var sig: Signature @@ -834,3 +823,30 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var Signature): int = result = -1 else: value = sig + +proc getField*[T: PublicKey|PrivateKey](pb: ProtoBuffer, field: int, + value: var T): bool = + var buffer: seq[byte] + var key: T + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if key.init(buffer): + value = key + true + else: + false + +proc getField*(pb: ProtoBuffer, field: int, value: var Signature): bool = + var buffer: seq[byte] + var sig: Signature + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if sig.init(buffer): + value = sig + true + else: + false diff --git a/libp2p/multiaddress.nim b/libp2p/multiaddress.nim index c6064f075..3d6d85eb6 100644 --- a/libp2p/multiaddress.nim +++ b/libp2p/multiaddress.nim @@ -14,9 +14,10 @@ import nativesockets import tables, strutils, stew/shims/net import chronos -import multicodec, multihash, multibase, transcoder, vbuffer, peerid +import multicodec, multihash, multibase, transcoder, vbuffer, peerid, + protobuf/minprotobuf import stew/[base58, base32, endians2, results] -export results +export results, minprotobuf, vbuffer type MAKind* = enum @@ -477,7 +478,8 @@ proc protoName*(ma: MultiAddress): MaResult[string] = else: ok($(proto.mcodec)) -proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] = +proc protoArgument*(ma: MultiAddress, + value: var openarray[byte]): MaResult[int] = ## Returns MultiAddress ``ma`` protocol argument value. ## ## If current MultiAddress do not have argument value, then result will be @@ -496,8 +498,8 @@ proc protoArgument*(ma: MultiAddress, value: var openarray[byte]): MaResult[int] var res: int if proto.kind == Fixed: res = proto.size - if len(value) >= res and - vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size: + if len(value) >= res and + vb.data.readArray(value.toOpenArray(0, proto.size - 1)) != proto.size: err("multiaddress: Decoding protocol error") else: ok(res) @@ -580,7 +582,8 @@ iterator items*(ma: MultiAddress): MaResult[MultiAddress] = let proto = CodeAddresses.getOrDefault(MultiCodec(header)) if proto.kind == None: - yield err(MaResult[MultiAddress], "Unsupported protocol '" & $header & "'") + yield err(MaResult[MultiAddress], "Unsupported protocol '" & + $header & "'") elif proto.kind == Fixed: data.setLen(proto.size) @@ -609,7 +612,8 @@ proc contains*(ma: MultiAddress, codec: MultiCodec): MaResult[bool] {.inline.} = return ok(true) ok(false) -proc `[]`*(ma: MultiAddress, codec: MultiCodec): MaResult[MultiAddress] {.inline.} = +proc `[]`*(ma: MultiAddress, + codec: MultiCodec): MaResult[MultiAddress] {.inline.} = ## Returns partial MultiAddress with MultiCodec ``codec`` and present in ## MultiAddress ``ma``. 
for item in ma.items: @@ -634,7 +638,8 @@ proc toString*(value: MultiAddress): MaResult[string] = return err("multiaddress: Unsupported protocol '" & $header & "'") if proto.kind in {Fixed, Length, Path}: if isNil(proto.coder.bufferToString): - return err("multiaddress: Missing protocol '" & $(proto.mcodec) & "' coder") + return err("multiaddress: Missing protocol '" & $(proto.mcodec) & + "' coder") if not proto.coder.bufferToString(vb.data, part): return err("multiaddress: Decoding protocol error") parts.add($(proto.mcodec)) @@ -729,12 +734,14 @@ proc init*( of None: raiseAssert "None checked above" -proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: PeerID): MaResult[MultiAddress] {.inline.} = +proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, + value: PeerID): MaResult[MultiAddress] {.inline.} = ## Initialize MultiAddress object from protocol id ``protocol`` and peer id ## ``value``. init(mtype, protocol, value.data) -proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, value: int): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], protocol: MultiCodec, + value: int): MaResult[MultiAddress] = ## Initialize MultiAddress object from protocol id ``protocol`` and integer ## ``value``. This procedure can be used to instantiate ``tcp``, ``udp``, ## ``dccp`` and ``sctp`` MultiAddresses. @@ -759,7 +766,8 @@ proc getProtocol(name: string): MAProtocol {.inline.} = if mc != InvalidMultiCodec: result = CodeAddresses.getOrDefault(mc) -proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], + value: string): MaResult[MultiAddress] = ## Initialize MultiAddress object from string representation ``value``. var parts = value.trimRight('/').split('/') if len(parts[0]) != 0: @@ -776,7 +784,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] else: if proto.kind in {Fixed, Length, Path}: if isNil(proto.coder.stringToBuffer): - return err("multiaddress: Missing protocol '" & part & "' transcoder") + return err("multiaddress: Missing protocol '" & + part & "' transcoder") if offset + 1 >= len(parts): return err("multiaddress: Missing protocol '" & part & "' argument") @@ -785,14 +794,16 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] res.data.write(proto.mcodec) let res = proto.coder.stringToBuffer(parts[offset + 1], res.data) if not res: - return err("multiaddress: Error encoding `" & part & "/" & parts[offset + 1] & "`") + return err("multiaddress: Error encoding `" & part & "/" & + parts[offset + 1] & "`") offset += 2 elif proto.kind == Path: var path = "/" & (parts[(offset + 1)..^1].join("/")) res.data.write(proto.mcodec) if not proto.coder.stringToBuffer(path, res.data): - return err("multiaddress: Error encoding `" & part & "/" & path & "`") + return err("multiaddress: Error encoding `" & part & "/" & + path & "`") break elif proto.kind == Marker: @@ -801,8 +812,8 @@ proc init*(mtype: typedesc[MultiAddress], value: string): MaResult[MultiAddress] res.data.finish() ok(res) - -proc init*(mtype: typedesc[MultiAddress], data: openarray[byte]): MaResult[MultiAddress] = +proc init*(mtype: typedesc[MultiAddress], + data: openarray[byte]): MaResult[MultiAddress] = ## Initialize MultiAddress with array of bytes ``data``. 
if len(data) == 0: err("multiaddress: Address could not be empty!") @@ -836,10 +847,12 @@ proc init*(mtype: typedesc[MultiAddress], var data = initVBuffer() data.write(familyProto.mcodec) var written = familyProto.coder.stringToBuffer($address, data) - doAssert written, "Merely writing a string to a buffer should always be possible" + doAssert written, + "Merely writing a string to a buffer should always be possible" data.write(protoProto.mcodec) written = protoProto.coder.stringToBuffer($port, data) - doAssert written, "Merely writing a string to a buffer should always be possible" + doAssert written, + "Merely writing a string to a buffer should always be possible" data.finish() MultiAddress(data: data) @@ -890,14 +903,16 @@ proc append*(m1: var MultiAddress, m2: MultiAddress): MaResult[void] = else: ok() -proc `&`*(m1, m2: MultiAddress): MultiAddress {.raises: [Defect, ResultError[string]].} = +proc `&`*(m1, m2: MultiAddress): MultiAddress {. + raises: [Defect, ResultError[string]].} = ## Concatenates two addresses ``m1`` and ``m2``, and returns result. ## ## This procedure performs validation of concatenated result and can raise ## exception on error. concat(m1, m2).tryGet() -proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {.raises: [Defect, ResultError[string]].} = +proc `&=`*(m1: var MultiAddress, m2: MultiAddress) {. + raises: [Defect, ResultError[string]].} = ## Concatenates two addresses ``m1`` and ``m2``. ## ## This procedure performs validation of concatenated result and can raise @@ -1005,3 +1020,36 @@ proc `$`*(pat: MaPattern): string = result = "(" & sub.join("|") & ")" elif pat.operator == Eq: result = $pat.value + +proc write*(pb: var ProtoBuffer, field: int, value: MultiAddress) {.inline.} = + write(pb, field, value.data.buffer) + +proc getField*(pb: var ProtoBuffer, field: int, + value: var MultiAddress): bool {.inline.} = + var buffer: seq[byte] + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + let ma = MultiAddress.init(buffer) + if ma.isOk(): + value = ma.get() + true + else: + false + +proc getRepeatedField*(pb: var ProtoBuffer, field: int, + value: var seq[MultiAddress]): bool {.inline.} = + var items: seq[seq[byte]] + value.setLen(0) + if not(getRepeatedField(pb, field, items)): + return false + if len(items) == 0: + return true + for item in items: + let ma = MultiAddress.init(item) + if ma.isOk(): + value.add(ma.get()) + else: + value.setLen(0) + return false diff --git a/libp2p/peerid.nim b/libp2p/peerid.nim index e20417d0c..5c81c664b 100644 --- a/libp2p/peerid.nim +++ b/libp2p/peerid.nim @@ -200,11 +200,12 @@ proc write*(vb: var VBuffer, pid: PeerID) {.inline.} = ## Write PeerID value ``peerid`` to buffer ``vb``. vb.writeSeq(pid.data) -proc initProtoField*(index: int, pid: PeerID): ProtoField = +proc initProtoField*(index: int, pid: PeerID): ProtoField {.deprecated.} = ## Initialize ProtoField with PeerID ``value``. result = initProtoField(index, pid.data) -proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int = +proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int {. + deprecated.} = ## Read ``PeerID`` from ProtoBuf's message and validate it. var pid: PeerID result = getLengthValue(data, field, pid.data) @@ -213,3 +214,21 @@ proc getValue*(data: var ProtoBuffer, field: int, value: var PeerID): int = result = -1 else: value = pid + +proc write*(pb: var ProtoBuffer, field: int, pid: PeerID) = + ## Write PeerID value ``peerid`` to object ``pb`` using ProtoBuf's encoding. 
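
# getRepeatedField for MultiAddress above is all-or-nothing: one unparsable
# entry clears the output and fails the whole read instead of returning a
# partial list.  The same policy on a toy input, with parseInt standing in for
# MultiAddress.init:

import strutils

proc parseAll(items: seq[string], output: var seq[int]): bool =
  output.setLen(0)
  for item in items:
    try:
      output.add(parseInt(item))
    except ValueError:
      output.setLen(0)   # wipe partial results so callers cannot misuse them
      return false
  true

when isMainModule:
  var numbers: seq[int]
  doAssert parseAll(@["1", "2", "3"], numbers) and numbers == @[1, 2, 3]
  doAssert not parseAll(@["1", "oops"], numbers) and numbers.len == 0
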
+ write(pb, field, pid.data) + +proc getField*(pb: ProtoBuffer, field: int, pid: var PeerID): bool = + ## Read ``PeerID`` from ProtoBuf's message and validate it + var buffer: seq[byte] + var peerId: PeerID + if not(getField(pb, field, buffer)): + return false + if len(buffer) == 0: + return false + if peerId.init(buffer): + pid = peerId + true + else: + false diff --git a/libp2p/protobuf/minprotobuf.nim b/libp2p/protobuf/minprotobuf.nim index caef22c1f..5a00c4726 100644 --- a/libp2p/protobuf/minprotobuf.nim +++ b/libp2p/protobuf/minprotobuf.nim @@ -11,7 +11,7 @@ {.push raises: [Defect].} -import ../varint +import ../varint, stew/endians2 const MaxMessageSize* = 1'u shl 22 @@ -32,10 +32,14 @@ type offset*: int length*: int + ProtoHeader* = object + wire*: ProtoFieldKind + index*: uint64 + ProtoField* = object ## Protobuf's message field representation object - index: int - case kind: ProtoFieldKind + index*: int + case kind*: ProtoFieldKind of Varint: vint*: uint64 of Fixed64: @@ -47,13 +51,35 @@ type of StartGroup, EndGroup: discard -template protoHeader*(index: int, wire: ProtoFieldKind): uint = - ## Get protobuf's field header integer for ``index`` and ``wire``. - ((uint(index) shl 3) or cast[uint](wire)) + ProtoResult {.pure.} = enum + VarintDecodeError, + MessageIncompleteError, + BufferOverflowError, + MessageSizeTooBigError, + NoError -template protoHeader*(field: ProtoField): uint = + ProtoScalar* = uint | uint32 | uint64 | zint | zint32 | zint64 | + hint | hint32 | hint64 | float32 | float64 + +const + SupportedWireTypes* = { + int(ProtoFieldKind.Varint), + int(ProtoFieldKind.Fixed64), + int(ProtoFieldKind.Length), + int(ProtoFieldKind.Fixed32) + } + +template checkFieldNumber*(i: int) = + doAssert((i > 0 and i < (1 shl 29)) and not(i >= 19000 and i <= 19999), + "Incorrect or reserved field number") + +template getProtoHeader*(index: int, wire: ProtoFieldKind): uint64 = + ## Get protobuf's field header integer for ``index`` and ``wire``. + ((uint64(index) shl 3) or uint64(wire)) + +template getProtoHeader*(field: ProtoField): uint64 = ## Get protobuf's field header integer for ``field``. - ((uint(field.index) shl 3) or cast[uint](field.kind)) + ((uint64(field.index) shl 3) or uint64(field.kind)) template toOpenArray*(pb: ProtoBuffer): untyped = toOpenArray(pb.buffer, pb.offset, len(pb.buffer) - 1) @@ -72,20 +98,20 @@ template getLen*(pb: ProtoBuffer): int = proc vsizeof*(field: ProtoField): int {.inline.} = ## Returns number of bytes required to store protobuf's field ``field``. - result = vsizeof(protoHeader(field)) case field.kind of ProtoFieldKind.Varint: - result += vsizeof(field.vint) + vsizeof(getProtoHeader(field)) + vsizeof(field.vint) of ProtoFieldKind.Fixed64: - result += sizeof(field.vfloat64) + vsizeof(getProtoHeader(field)) + sizeof(field.vfloat64) of ProtoFieldKind.Fixed32: - result += sizeof(field.vfloat32) + vsizeof(getProtoHeader(field)) + sizeof(field.vfloat32) of ProtoFieldKind.Length: - result += vsizeof(uint(len(field.vbuffer))) + len(field.vbuffer) + vsizeof(getProtoHeader(field)) + vsizeof(uint64(len(field.vbuffer))) + + len(field.vbuffer) else: - discard + 0 -proc initProtoField*(index: int, value: SomeVarint): ProtoField = +proc initProtoField*(index: int, value: SomeVarint): ProtoField {.deprecated.} = ## Initialize ProtoField with integer value. 
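
# checkFieldNumber above enforces the protobuf rules for field numbers: valid
# numbers are 1 .. 2^29 - 1 with 19000..19999 reserved, and getProtoHeader
# packs the number together with the 3-bit wire type.  A small standalone
# sketch of both:

proc validFieldNumber(i: int): bool =
  i > 0 and i < (1 shl 29) and not (i >= 19000 and i <= 19999)

proc protoHeader(field, wireType: int): uint64 =
  doAssert validFieldNumber(field)
  (uint64(field) shl 3) or uint64(wireType)

when isMainModule:
  doAssert protoHeader(2, 2) == 0x12   # field 2, length-delimited
  doAssert not validFieldNumber(0)
  doAssert not validFieldNumber(19000)
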
result = ProtoField(kind: Varint, index: index) when type(value) is uint64: @@ -93,26 +119,28 @@ proc initProtoField*(index: int, value: SomeVarint): ProtoField = else: result.vint = cast[uint64](value) -proc initProtoField*(index: int, value: bool): ProtoField = +proc initProtoField*(index: int, value: bool): ProtoField {.deprecated.} = ## Initialize ProtoField with integer value. result = ProtoField(kind: Varint, index: index) result.vint = byte(value) -proc initProtoField*(index: int, value: openarray[byte]): ProtoField = +proc initProtoField*(index: int, + value: openarray[byte]): ProtoField {.deprecated.} = ## Initialize ProtoField with bytes array. result = ProtoField(kind: Length, index: index) if len(value) > 0: result.vbuffer = newSeq[byte](len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) -proc initProtoField*(index: int, value: string): ProtoField = +proc initProtoField*(index: int, value: string): ProtoField {.deprecated.} = ## Initialize ProtoField with string. result = ProtoField(kind: Length, index: index) if len(value) > 0: result.vbuffer = newSeq[byte](len(value)) copyMem(addr result.vbuffer[0], unsafeAddr value[0], len(value)) -proc initProtoField*(index: int, value: ProtoBuffer): ProtoField {.inline.} = +proc initProtoField*(index: int, + value: ProtoBuffer): ProtoField {.deprecated, inline.} = ## Initialize ProtoField with nested message stored in ``value``. ## ## Note: This procedure performs shallow copy of ``value`` sequence. @@ -127,6 +155,13 @@ proc initProtoBuffer*(data: seq[byte], offset = 0, result.offset = offset result.options = options +proc initProtoBuffer*(data: openarray[byte], offset = 0, + options: set[ProtoFlags] = {}): ProtoBuffer = + ## Initialize ProtoBuffer with copy of ``data``. + result.buffer = @data + result.offset = offset + result.options = options + proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = ## Initialize ProtoBuffer with new sequence of capacity ``cap``. result.buffer = newSeqOfCap[byte](128) @@ -138,16 +173,134 @@ proc initProtoBuffer*(options: set[ProtoFlags] = {}): ProtoBuffer = result.offset = 10 elif {WithUint32LeLength, WithUint32BeLength} * options != {}: # Our buffer will start from position 4, so we can store length of buffer - # in [0, 9]. + # in [0, 3]. 
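
# The vsizeof calls above size the buffer up front: a varint costs one byte
# per started group of 7 bits, so values below 128 fit in one byte, below
# 16384 in two, and a full 64-bit value needs ten.  A standalone equivalent:

proc varintSize(x: uint64): int =
  var v = x
  result = 1
  while v >= 0x80'u64:
    v = v shr 7
    inc result

when isMainModule:
  doAssert varintSize(0) == 1
  doAssert varintSize(127) == 1
  doAssert varintSize(128) == 2
  doAssert varintSize(0xFFFF_FFFF_FFFF_FFFF'u64) == 10
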
result.buffer.setLen(4) result.offset = 4 -proc write*(pb: var ProtoBuffer, field: ProtoField) = +proc write*[T: ProtoScalar](pb: var ProtoBuffer, + field: int, value: T) = + checkFieldNumber(field) + var length = 0 + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Varint)) + + vsizeof(value) + let header = ProtoFieldKind.Varint + elif T is float32: + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed32)) + + sizeof(T) + let header = ProtoFieldKind.Fixed32 + elif T is float64: + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Fixed64)) + + sizeof(T) + let header = ProtoFieldKind.Fixed64 + + pb.buffer.setLen(len(pb.buffer) + flength) + + let hres = PB.putUVarint(pb.toOpenArray(), length, + getProtoHeader(field, header)) + doAssert(hres.isOk()) + pb.offset += length + when (T is uint64) or (T is uint32) or (T is uint): + let vres = PB.putUVarint(pb.toOpenArray(), length, value) + doAssert(vres.isOk()) + pb.offset += length + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + let vres = putSVarint(pb.toOpenArray(), length, value) + doAssert(vres.isOk()) + pb.offset += length + elif T is float32: + doAssert(pb.isEnough(sizeof(T))) + let u32 = cast[uint32](value) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE() + pb.offset += sizeof(T) + elif T is float64: + doAssert(pb.isEnough(sizeof(T))) + let u64 = cast[uint64](value) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE() + pb.offset += sizeof(T) + +proc writePacked*[T: ProtoScalar](pb: var ProtoBuffer, field: int, + value: openarray[T]) = + checkFieldNumber(field) + var length = 0 + let dlength = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + var res = 0 + for item in value: + res += vsizeof(item) + res + elif (T is float32) or (T is float64): + len(value) * sizeof(T) + + let header = getProtoHeader(field, ProtoFieldKind.Length) + let flength = vsizeof(header) + vsizeof(uint64(dlength)) + dlength + pb.buffer.setLen(len(pb.buffer) + flength) + let hres = PB.putUVarint(pb.toOpenArray(), length, header) + doAssert(hres.isOk()) + pb.offset += length + length = 0 + let lres = PB.putUVarint(pb.toOpenArray(), length, uint64(dlength)) + doAssert(lres.isOk()) + pb.offset += length + for item in value: + when (T is uint64) or (T is uint32) or (T is uint): + length = 0 + let vres = PB.putUVarint(pb.toOpenArray(), length, item) + doAssert(vres.isOk()) + pb.offset += length + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + length = 0 + let vres = PB.putSVarint(pb.toOpenArray(), length, item) + doAssert(vres.isOk()) + pb.offset += length + elif T is float32: + doAssert(pb.isEnough(sizeof(T))) + let u32 = cast[uint32](item) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u32.toBytesLE() + pb.offset += sizeof(T) + elif T is float64: + doAssert(pb.isEnough(sizeof(T))) + let u64 = cast[uint64](item) + pb.buffer[pb.offset ..< pb.offset + sizeof(T)] = u64.toBytesLE() + pb.offset += sizeof(T) + +proc write*[T: byte|char](pb: var ProtoBuffer, field: int, + value: openarray[T]) = + checkFieldNumber(field) + var length = 0 + let flength = vsizeof(getProtoHeader(field, ProtoFieldKind.Length)) + + vsizeof(uint64(len(value))) + 
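
# The signed branches above go through putSVarint; in .proto terms, sint
# values are zigzag-mapped to unsigned first so that small negative numbers
# still encode into one or two bytes instead of ten.  The mapping itself:

proc zigzagEncode(n: int64): uint64 =
  if n >= 0: uint64(n) shl 1
  else: (uint64(not n) shl 1) or 1

proc zigzagDecode(z: uint64): int64 =
  if (z and 1) == 0: int64(z shr 1)
  else: not int64(z shr 1)

when isMainModule:
  doAssert zigzagEncode(0) == 0
  doAssert zigzagEncode(-1) == 1
  doAssert zigzagEncode(1) == 2
  doAssert zigzagEncode(-2) == 3
  doAssert zigzagDecode(zigzagEncode(-123456789)) == -123456789
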
len(value) + pb.buffer.setLen(len(pb.buffer) + flength) + let hres = PB.putUVarint(pb.toOpenArray(), length, + getProtoHeader(field, ProtoFieldKind.Length)) + doAssert(hres.isOk()) + pb.offset += length + let lres = PB.putUVarint(pb.toOpenArray(), length, + uint64(len(value))) + doAssert(lres.isOk()) + pb.offset += length + if len(value) > 0: + doAssert(pb.isEnough(len(value))) + copyMem(addr pb.buffer[pb.offset], unsafeAddr value[0], len(value)) + pb.offset += len(value) + +proc write*(pb: var ProtoBuffer, field: int, value: ProtoBuffer) {.inline.} = + ## Encode Protobuf's sub-message ``value`` and store it to protobuf's buffer + ## ``pb`` with field number ``field``. + write(pb, field, value.buffer) + +proc write*(pb: var ProtoBuffer, field: ProtoField) {.deprecated.} = ## Encode protobuf's field ``field`` and store it to protobuf's buffer ``pb``. var length = 0 var res: VarintResult[void] pb.buffer.setLen(len(pb.buffer) + vsizeof(field)) - res = PB.putUVarint(pb.toOpenArray(), length, protoHeader(field)) + res = PB.putUVarint(pb.toOpenArray(), length, getProtoHeader(field)) doAssert(res.isOk()) pb.offset += length case field.kind @@ -199,31 +352,440 @@ proc finish*(pb: var ProtoBuffer) = pb.offset = pos elif WithUint32BeLength in pb.options: let size = uint(len(pb.buffer) - 4) - pb.buffer[0] = byte((size shr 24) and 0xFF'u) - pb.buffer[1] = byte((size shr 16) and 0xFF'u) - pb.buffer[2] = byte((size shr 8) and 0xFF'u) - pb.buffer[3] = byte(size and 0xFF'u) + pb.buffer[0 ..< 4] = toBytesBE(uint32(size)) pb.offset = 4 elif WithUint32LeLength in pb.options: let size = uint(len(pb.buffer) - 4) - pb.buffer[0] = byte(size and 0xFF'u) - pb.buffer[1] = byte((size shr 8) and 0xFF'u) - pb.buffer[2] = byte((size shr 16) and 0xFF'u) - pb.buffer[3] = byte((size shr 24) and 0xFF'u) + pb.buffer[0 ..< 4] = toBytesLE(uint32(size)) pb.offset = 4 else: pb.offset = 0 +proc getHeader(data: var ProtoBuffer, header: var ProtoHeader): bool = + var length = 0 + var hdr = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, hdr).isOk(): + let index = uint64(hdr shr 3) + let wire = hdr and 0x07 + if wire in SupportedWireTypes: + data.offset += length + header = ProtoHeader(index: index, wire: cast[ProtoFieldKind](wire)) + true + else: + false + else: + false + +proc skipValue(data: var ProtoBuffer, header: ProtoHeader): bool = + case header.wire + of ProtoFieldKind.Varint: + var length = 0 + var value = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + true + else: + false + of ProtoFieldKind.Fixed32: + if data.isEnough(sizeof(uint32)): + data.offset += sizeof(uint32) + true + else: + false + of ProtoFieldKind.Fixed64: + if data.isEnough(sizeof(uint64)): + data.offset += sizeof(uint64) + true + else: + false + of ProtoFieldKind.Length: + var length = 0 + var bsize = 0'u64 + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + data.offset += int(bsize) + true + else: + false + else: + false + else: + false + of ProtoFieldKind.StartGroup, ProtoFieldKind.EndGroup: + false + +proc getValue[T: ProtoScalar](data: var ProtoBuffer, + header: ProtoHeader, + outval: var T): ProtoResult = + when (T is uint64) or (T is uint32) or (T is uint): + doAssert(header.wire == ProtoFieldKind.Varint) + var length = 0 + var value = T(0) + if PB.getUVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + outval = value + ProtoResult.NoError + else: + 
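
# writePacked above emits a single length-delimited field whose payload is
# just the concatenated scalar encodings, rather than repeating the field
# header per element.  A worked example for field 4 with the varint values
# [1, 2, 300]:

proc putUVarint(dst: var seq[byte], x: uint64) =
  var v = x
  while v >= 0x80'u64:
    dst.add(byte(v and 0x7F) or 0x80'u8)
    v = v shr 7
  dst.add(byte(v))

proc writePackedVarints(dst: var seq[byte], field: int, values: seq[uint64]) =
  var payload: seq[byte]
  for v in values:
    payload.putUVarint(v)
  dst.putUVarint((uint64(field) shl 3) or 2)   # wire type 2: length-delimited
  dst.putUVarint(uint64(payload.len))
  dst.add(payload)

when isMainModule:
  var buf: seq[byte]
  buf.writePackedVarints(4, @[1'u64, 2, 300])
  doAssert buf == @[byte 0x22, 0x04, 0x01, 0x02, 0xAC, 0x02]
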
ProtoResult.VarintDecodeError + elif (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + doAssert(header.wire == ProtoFieldKind.Varint) + var length = 0 + var value = T(0) + if getSVarint(data.toOpenArray(), length, value).isOk(): + data.offset += length + outval = value + ProtoResult.NoError + else: + ProtoResult.VarintDecodeError + elif T is float32: + doAssert(header.wire == ProtoFieldKind.Fixed32) + if data.isEnough(sizeof(float32)): + outval = cast[float32](fromBytesLE(uint32, data.toOpenArray())) + data.offset += sizeof(float32) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + elif T is float64: + doAssert(header.wire == ProtoFieldKind.Fixed64) + if data.isEnough(sizeof(float64)): + outval = cast[float64](fromBytesLE(uint64, data.toOpenArray())) + data.offset += sizeof(float64) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + +proc getValue[T:byte|char](data: var ProtoBuffer, header: ProtoHeader, + outBytes: var openarray[T], + outLength: var int): ProtoResult = + doAssert(header.wire == ProtoFieldKind.Length) + var length = 0 + var bsize = 0'u64 + + outLength = 0 + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + outLength = int(bsize) + if len(outBytes) >= int(bsize): + if bsize > 0'u64: + copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) + data.offset += int(bsize) + ProtoResult.NoError + else: + # Buffer overflow should not be critical failure + data.offset += int(bsize) + ProtoResult.BufferOverflowError + else: + ProtoResult.MessageIncompleteError + else: + ProtoResult.MessageSizeTooBigError + else: + ProtoResult.VarintDecodeError + +proc getValue[T:seq[byte]|string](data: var ProtoBuffer, header: ProtoHeader, + outBytes: var T): ProtoResult = + doAssert(header.wire == ProtoFieldKind.Length) + var length = 0 + var bsize = 0'u64 + outBytes.setLen(0) + + if PB.getUVarint(data.toOpenArray(), length, bsize).isOk(): + data.offset += length + if bsize <= uint64(MaxMessageSize): + if data.isEnough(int(bsize)): + outBytes.setLen(bsize) + if bsize > 0'u64: + copyMem(addr outBytes[0], addr data.buffer[data.offset], int(bsize)) + data.offset += int(bsize) + ProtoResult.NoError + else: + ProtoResult.MessageIncompleteError + else: + ProtoResult.MessageSizeTooBigError + else: + ProtoResult.VarintDecodeError + +proc getField*[T: ProtoScalar](data: ProtoBuffer, field: int, + output: var T): bool = + checkFieldNumber(field) + var value: T + var res = false + var pb = data + output = T(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output = T(0) + return false + let wireCheck = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + header.wire == ProtoFieldKind.Varint + elif T is float32: + header.wire == ProtoFieldKind.Fixed32 + elif T is float64: + header.wire == ProtoFieldKind.Fixed64 + if header.index == uint64(field): + if wireCheck: + let r = getValue(pb, header, value) + case r + of ProtoResult.NoError: + res = true + output = value + else: + return false + else: + # We are ignoring wire types different from what we expect, because it + # is how `protoc` is working. 
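
# Every getValue branch above starts by pulling a varint off the front of the
# buffer, either as the value itself or as a length prefix.  A standalone
# LEB128 decoder matching the encoders sketched earlier:

proc getUVarint(data: openArray[byte], consumed: var int): uint64 =
  consumed = 0
  var shift = 0
  for b in data:
    inc consumed
    result = result or (uint64(b and 0x7F) shl shift)
    if (b and 0x80'u8) == 0'u8:
      return
    shift += 7
    if shift >= 64:
      break
  raise newException(ValueError, "truncated or oversized varint")

when isMainModule:
  var used = 0
  doAssert getUVarint(@[byte 0xAC, 0x02, 0xFF], used) == 300
  doAssert used == 2   # trailing bytes are left untouched
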
+ if not(skipValue(pb, header)): + output = T(0) + return false + else: + if not(skipValue(pb, header)): + output = T(0) + return false + res + +proc getField*[T: byte|char](data: ProtoBuffer, field: int, + output: var openarray[T], + outlen: var int): bool = + checkFieldNumber(field) + var pb = data + var res = false + + outlen = 0 + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + let r = getValue(pb, header, output, outlen) + case r + of ProtoResult.NoError: + res = true + of ProtoResult.BufferOverflowError: + # Buffer overflow error is not critical error, we still can get + # field values with proper size. + discard + else: + if len(output) > 0: + zeroMem(addr output[0], len(output)) + return false + else: + # We are ignoring wire types different from ProtoFieldKind.Length, + # because it is how `protoc` is working. + if not(skipValue(pb, header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + else: + if not(skipValue(pb, header)): + if len(output) > 0: + zeroMem(addr output[0], len(output)) + outlen = 0 + return false + + res + +proc getField*[T: seq[byte]|string](data: ProtoBuffer, field: int, + output: var T): bool = + checkFieldNumber(field) + var res = false + var pb = data + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + let r = getValue(pb, header, output) + case r + of ProtoResult.NoError: + res = true + of ProtoResult.BufferOverflowError: + # Buffer overflow error is not critical error, we still can get + # field values with proper size. + discard + else: + output.setLen(0) + return false + else: + # We are ignoring wire types different from ProtoFieldKind.Length, + # because it is how `protoc` is working. + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + res + +proc getField*(pb: ProtoBuffer, field: int, output: var ProtoBuffer): bool {. 
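
# getField above scans the whole message, silently skipping unknown fields and
# wire-type mismatches (mirroring protoc's behaviour), and when the same field
# occurs more than once the last occurrence wins.  That policy in miniature,
# over already decoded (field, value) pairs:

proc lastValue(fields: seq[(int, uint64)], wanted: int,
               output: var uint64): bool =
  output = 0
  for pair in fields:
    if pair[0] == wanted:
      output = pair[1]   # later occurrences overwrite earlier ones
      result = true
    # any other field is simply skipped, never an error

when isMainModule:
  var v: uint64
  doAssert lastValue(@[(1, 7'u64), (9, 1'u64), (1, 42'u64)], 1, v)
  doAssert v == 42
  doAssert not lastValue(@[(2, 5'u64)], 1, v)
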
+ inline.} = + var buffer: seq[byte] + if pb.getField(field, buffer): + output = initProtoBuffer(buffer) + true + else: + false + +proc getRepeatedField*[T: seq[byte]|string](data: ProtoBuffer, field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + var item: T + let r = getValue(pb, header, item) + case r + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + +proc getRepeatedField*[T: uint64|float32|float64](data: ProtoBuffer, + field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire in {ProtoFieldKind.Varint, ProtoFieldKind.Fixed32, + ProtoFieldKind.Fixed64}: + var item: T + let r = getValue(pb, header, item) + case r + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + +proc getPackedRepeatedField*[T: ProtoScalar](data: ProtoBuffer, field: int, + output: var seq[T]): bool = + checkFieldNumber(field) + var pb = data + output.setLen(0) + + while not(pb.isEmpty()): + var header: ProtoHeader + if not(pb.getHeader(header)): + output.setLen(0) + return false + + if header.index == uint64(field): + if header.wire == ProtoFieldKind.Length: + var arritem: seq[byte] + let rarr = getValue(pb, header, arritem) + case rarr + of ProtoResult.NoError: + var pbarr = initProtoBuffer(arritem) + let itemHeader = + when (T is uint64) or (T is uint32) or (T is uint) or + (T is zint64) or (T is zint32) or (T is zint) or + (T is hint64) or (T is hint32) or (T is hint): + ProtoHeader(wire: ProtoFieldKind.Varint) + elif T is float32: + ProtoHeader(wire: ProtoFieldKind.Fixed32) + elif T is float64: + ProtoHeader(wire: ProtoFieldKind.Fixed64) + while not(pbarr.isEmpty()): + var item: T + let res = getValue(pbarr, itemHeader, item) + case res + of ProtoResult.NoError: + output.add(item) + else: + output.setLen(0) + return false + else: + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + else: + if not(skipValue(pb, header)): + output.setLen(0) + return false + + if len(output) > 0: + true + else: + false + proc getVarintValue*(data: var ProtoBuffer, field: int, - value: var SomeVarint): int = + value: var SomeVarint): int {.deprecated.} = ## Get value of `Varint` type. 
var length = 0 var header = 0'u64 var soffset = data.offset - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, header).isOk(): data.offset += length - if header == protoHeader(field, Varint): + if header == getProtoHeader(field, Varint): if not data.isEmpty(): when type(value) is int32 or type(value) is int64 or type(value) is int: let res = getSVarint(data.toOpenArray(), length, value) @@ -237,7 +799,7 @@ proc getVarintValue*(data: var ProtoBuffer, field: int, data.offset = soffset proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, - buffer: var T): int = + buffer: var T): int {.deprecated.} = ## Get value of `Length` type. var length = 0 var header = 0'u64 @@ -245,10 +807,12 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, var soffset = data.offset result = -1 buffer.setLen(0) - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, header).isOk(): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, header).isOk(): data.offset += length - if header == protoHeader(field, Length): - if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), length, ssize).isOk(): + if header == getProtoHeader(field, Length): + if not data.isEmpty() and PB.getUVarint(data.toOpenArray(), + length, ssize).isOk(): data.offset += length if ssize <= MaxMessageSize and data.isEnough(int(ssize)): buffer.setLen(ssize) @@ -262,16 +826,16 @@ proc getLengthValue*[T: string|seq[byte]](data: var ProtoBuffer, field: int, data.offset = soffset proc getBytes*(data: var ProtoBuffer, field: int, - buffer: var seq[byte]): int {.inline.} = + buffer: var seq[byte]): int {.deprecated, inline.} = ## Get value of `Length` type as bytes. result = getLengthValue(data, field, buffer) proc getString*(data: var ProtoBuffer, field: int, - buffer: var string): int {.inline.} = + buffer: var string): int {.deprecated, inline.} = ## Get value of `Length` type as string. result = getLengthValue(data, field, buffer) -proc enterSubmessage*(pb: var ProtoBuffer): int = +proc enterSubmessage*(pb: var ProtoBuffer): int {.deprecated.} = ## Processes protobuf's sub-message and adjust internal offset to enter ## inside of sub-message. Returns field index of sub-message field or ## ``0`` on error. @@ -280,10 +844,12 @@ proc enterSubmessage*(pb: var ProtoBuffer): int = var msize = 0'u64 var soffset = pb.offset - if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, header).isOk(): + if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), + length, header).isOk(): pb.offset += length if (header and 0x07'u64) == cast[uint64](ProtoFieldKind.Length): - if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), length, msize).isOk(): + if not pb.isEmpty() and PB.getUVarint(pb.toOpenArray(), + length, msize).isOk(): pb.offset += length if msize <= MaxMessageSize and pb.isEnough(int(msize)): pb.length = int(msize) @@ -292,7 +858,7 @@ proc enterSubmessage*(pb: var ProtoBuffer): int = # Restore offset on error pb.offset = soffset -proc skipSubmessage*(pb: var ProtoBuffer) = +proc skipSubmessage*(pb: var ProtoBuffer) {.deprecated.} = ## Skip current protobuf's sub-message and adjust internal offset to the ## end of sub-message. 
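
# getPackedRepeatedField above unwraps the length-delimited payload and then
# keeps decoding scalars until that payload is exhausted.  Decoding the packed
# varint example from the writePacked sketch earlier:

proc readUVarint(data: openArray[byte], pos: var int): uint64 =
  var shift = 0
  while pos < data.len:
    let b = data[pos]
    inc pos
    result = result or (uint64(b and 0x7F) shl shift)
    if (b and 0x80'u8) == 0'u8:
      return
    shift += 7
  raise newException(ValueError, "truncated varint")

proc readPackedVarints(data: seq[byte]): seq[uint64] =
  var pos = 0
  let header = readUVarint(data, pos)
  doAssert (header and 0x7) == 2, "packed fields are length-delimited"
  let length = int(readUVarint(data, pos))
  let stop = pos + length
  while pos < stop:
    result.add(readUVarint(data, pos))

when isMainModule:
  doAssert readPackedVarints(@[byte 0x22, 0x04, 0x01, 0x02, 0xAC, 0x02]) ==
           @[1'u64, 2, 300]
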
doAssert(pb.length != 0) diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim index 735d740af..a998ea983 100644 --- a/libp2p/protocols/identify.nim +++ b/libp2p/protocols/identify.nim @@ -47,61 +47,49 @@ type proc encodeMsg*(peerInfo: PeerInfo, observedAddr: Multiaddress): ProtoBuffer = result = initProtoBuffer() - result.write(initProtoField(1, peerInfo.publicKey.get().getBytes().tryGet())) + result.write(1, peerInfo.publicKey.get().getBytes().tryGet()) for ma in peerInfo.addrs: - result.write(initProtoField(2, ma.data.buffer)) + result.write(2, ma.data.buffer) for proto in peerInfo.protocols: - result.write(initProtoField(3, proto)) + result.write(3, proto) - result.write(initProtoField(4, observedAddr.data.buffer)) + result.write(4, observedAddr.data.buffer) let protoVersion = ProtoVersion - result.write(initProtoField(5, protoVersion)) + result.write(5, protoVersion) let agentVersion = AgentVersion - result.write(initProtoField(6, agentVersion)) + result.write(6, agentVersion) result.finish() proc decodeMsg*(buf: seq[byte]): IdentifyInfo = var pb = initProtoBuffer(buf) - result.pubKey = none(PublicKey) var pubKey: PublicKey - if pb.getValue(1, pubKey) > 0: + if pb.getField(1, pubKey): trace "read public key from message", pubKey = ($pubKey).shortLog result.pubKey = some(pubKey) - result.addrs = newSeq[MultiAddress]() - var address = newSeq[byte]() - while pb.getBytes(2, address) > 0: - if len(address) != 0: - var copyaddr = address - var ma = MultiAddress.init(copyaddr).tryGet() - result.addrs.add(ma) - trace "read address bytes from message", address = ma - address.setLen(0) + if pb.getRepeatedField(2, result.addrs): + trace "read addresses from message", addresses = result.addrs - var proto = "" - while pb.getString(3, proto) > 0: - trace "read proto from message", proto = proto - result.protos.add(proto) - proto = "" + if pb.getRepeatedField(3, result.protos): + trace "read protos from message", protocols = result.protos - var observableAddr = newSeq[byte]() - if pb.getBytes(4, observableAddr) > 0: # attempt to read the observed addr - var ma = MultiAddress.init(observableAddr).tryGet() - trace "read observedAddr from message", address = ma - result.observedAddr = some(ma) + var observableAddr: MultiAddress + if pb.getField(4, observableAddr): + trace "read observableAddr from message", address = observableAddr + result.observedAddr = some(observableAddr) var protoVersion = "" - if pb.getString(5, protoVersion) > 0: + if pb.getField(5, protoVersion): trace "read protoVersion from message", protoVersion = protoVersion result.protoVersion = some(protoVersion) var agentVersion = "" - if pb.getString(6, agentVersion) > 0: + if pb.getField(6, agentVersion): trace "read agentVersion from message", agentVersion = agentVersion result.agentVersion = some(agentVersion) diff --git a/libp2p/protocols/pubsub/rpc/message.nim b/libp2p/protocols/pubsub/rpc/message.nim index d203035d4..9ff941853 100644 --- a/libp2p/protocols/pubsub/rpc/message.nim +++ b/libp2p/protocols/pubsub/rpc/message.nim @@ -32,9 +32,7 @@ func defaultMsgIdProvider*(m: Message): string = byteutils.toHex(m.seqno) & m.fromPeer.pretty proc sign*(msg: Message, p: PeerInfo): seq[byte] {.gcsafe, raises: [ResultError[CryptoError], Defect].} = - var buff = initProtoBuffer() - encodeMessage(msg, buff) - p.privateKey.sign(PubSubPrefix & buff.buffer).tryGet().getBytes() + p.privateKey.sign(PubSubPrefix & encodeMessage(msg)).tryGet().getBytes() proc verify*(m: Message, p: PeerInfo): bool = if m.signature.len > 0 and 
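
# decodeMsg above only wraps a value in some(...) when getField succeeded, so
# absent identify fields surface as none(...) rather than as empty defaults.
# The shape of that pattern, with a plain Table standing in for the decoded
# protobuf fields (5 = protoVersion, 6 = agentVersion, as in the code above):

import options, tables

proc decodeInfo(fields: Table[int, string]): tuple[protoVersion: Option[string],
                                                   agentVersion: Option[string]] =
  if 5 in fields:
    result.protoVersion = some(fields[5])
  if 6 in fields:
    result.agentVersion = some(fields[6])

when isMainModule:
  let info = decodeInfo({6: "example-agent/1.0"}.toTable)
  doAssert info.agentVersion.isSome
  doAssert info.protoVersion.isNone
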
m.key.len > 0: @@ -42,14 +40,11 @@ proc verify*(m: Message, p: PeerInfo): bool = msg.signature = @[] msg.key = @[] - var buff = initProtoBuffer() - encodeMessage(msg, buff) - var remote: Signature var key: PublicKey if remote.init(m.signature) and key.init(m.key): trace "verifying signature", remoteSignature = remote - result = remote.verify(PubSubPrefix & buff.buffer, key) + result = remote.verify(PubSubPrefix & encodeMessage(msg), key) if result: libp2p_pubsub_sig_verify_success.inc() diff --git a/libp2p/protocols/pubsub/rpc/protobuf.nim b/libp2p/protocols/pubsub/rpc/protobuf.nim index 922546b20..c5a3eb309 100644 --- a/libp2p/protocols/pubsub/rpc/protobuf.nim +++ b/libp2p/protocols/pubsub/rpc/protobuf.nim @@ -14,265 +14,247 @@ import messages, ../../../utility, ../../../protobuf/minprotobuf -proc encodeGraft*(graft: ControlGraft, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, graft.topicID)) +proc write*(pb: var ProtoBuffer, field: int, graft: ControlGraft) = + var ipb = initProtoBuffer() + ipb.write(1, graft.topicID) + ipb.finish() + pb.write(field, ipb) -proc decodeGraft*(pb: var ProtoBuffer): seq[ControlGraft] {.gcsafe.} = - trace "decoding graft msg", buffer = pb.buffer.shortLog - while true: - var topic: string - if pb.getString(1, topic) < 0: - break +proc write*(pb: var ProtoBuffer, field: int, prune: ControlPrune) = + var ipb = initProtoBuffer() + ipb.write(1, prune.topicID) + ipb.finish() + pb.write(field, ipb) - trace "read topic field from graft msg", topicID = topic - result.add(ControlGraft(topicID: topic)) - -proc encodePrune*(prune: ControlPrune, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, prune.topicID)) - -proc decodePrune*(pb: var ProtoBuffer): seq[ControlPrune] {.gcsafe.} = - trace "decoding prune msg" - while true: - var topic: string - if pb.getString(1, topic) < 0: - break - - trace "read topic field from prune msg", topicID = topic - result.add(ControlPrune(topicID: topic)) - -proc encodeIHave*(ihave: ControlIHave, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, ihave.topicID)) +proc write*(pb: var ProtoBuffer, field: int, ihave: ControlIHave) = + var ipb = initProtoBuffer() + ipb.write(1, ihave.topicID) for mid in ihave.messageIDs: - pb.write(initProtoField(2, mid)) + ipb.write(2, mid) + ipb.finish() + pb.write(field, ipb) -proc decodeIHave*(pb: var ProtoBuffer): seq[ControlIHave] {.gcsafe.} = - trace "decoding ihave msg" - - while true: - var control: ControlIHave - if pb.getString(1, control.topicID) < 0: - trace "topic field missing from ihave msg" - break - - trace "read topic field", topicID = control.topicID - - while true: - var mid: string - if pb.getString(2, mid) < 0: - break - trace "read messageID field", mid = mid - control.messageIDs.add(mid) - - result.add(control) - -proc encodeIWant*(iwant: ControlIWant, pb: var ProtoBuffer) {.gcsafe.} = +proc write*(pb: var ProtoBuffer, field: int, iwant: ControlIWant) = + var ipb = initProtoBuffer() for mid in iwant.messageIDs: - pb.write(initProtoField(1, mid)) + ipb.write(1, mid) + if len(ipb.buffer) > 0: + ipb.finish() + pb.write(field, ipb) -proc decodeIWant*(pb: var ProtoBuffer): seq[ControlIWant] {.gcsafe.} = - trace "decoding iwant msg" +proc write*(pb: var ProtoBuffer, field: int, control: ControlMessage) = + var ipb = initProtoBuffer() + for ihave in control.ihave: + ipb.write(1, ihave) + for iwant in control.iwant: + ipb.write(2, iwant) + for graft in control.graft: + ipb.write(3, graft) + for prune in control.prune: + ipb.write(4, prune) + if 
len(ipb.buffer) > 0: + ipb.finish() + pb.write(field, ipb) - var control: ControlIWant - while true: - var mid: string - if pb.getString(1, mid) < 0: - break - control.messageIDs.add(mid) - trace "read messageID field", mid = mid - result.add(control) - -proc encodeControl*(control: ControlMessage, pb: var ProtoBuffer) {.gcsafe.} = - if control.ihave.len > 0: - var ihave = initProtoBuffer() - for h in control.ihave: - h.encodeIHave(ihave) - - # write messages to protobuf - if ihave.buffer.len > 0: - ihave.finish() - pb.write(initProtoField(1, ihave)) - - if control.iwant.len > 0: - var iwant = initProtoBuffer() - for w in control.iwant: - w.encodeIWant(iwant) - - # write messages to protobuf - if iwant.buffer.len > 0: - iwant.finish() - pb.write(initProtoField(2, iwant)) - - if control.graft.len > 0: - var graft = initProtoBuffer() - for g in control.graft: - g.encodeGraft(graft) - - # write messages to protobuf - if graft.buffer.len > 0: - graft.finish() - pb.write(initProtoField(3, graft)) - - if control.prune.len > 0: - var prune = initProtoBuffer() - for p in control.prune: - p.encodePrune(prune) - - # write messages to protobuf - if prune.buffer.len > 0: - prune.finish() - pb.write(initProtoField(4, prune)) - -proc decodeControl*(pb: var ProtoBuffer): Option[ControlMessage] {.gcsafe.} = - trace "decoding control submessage" - var control: ControlMessage - while true: - var field = pb.enterSubMessage() - trace "processing submessage", field = field - case field: - of 0: - trace "no submessage found in Control msg" - break - of 1: - control.ihave &= pb.decodeIHave() - of 2: - control.iwant &= pb.decodeIWant() - of 3: - control.graft &= pb.decodeGraft() - of 4: - control.prune &= pb.decodePrune() - else: - raise newException(CatchableError, "message type not recognized") - - if result.isNone: - result = some(control) - -proc encodeSubs*(subs: SubOpts, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, subs.subscribe)) - pb.write(initProtoField(2, subs.topic)) - -proc decodeSubs*(pb: var ProtoBuffer): seq[SubOpts] {.gcsafe.} = - while true: - var subOpt: SubOpts - var subscr: uint - discard pb.getVarintValue(1, subscr) - subOpt.subscribe = cast[bool](subscr) - trace "read subscribe field", subscribe = subOpt.subscribe - - if pb.getString(2, subOpt.topic) < 0: - break - trace "read subscribe field", topicName = subOpt.topic - - result.add(subOpt) - - trace "got subscriptions", subscriptions = result - -proc encodeMessage*(msg: Message, pb: var ProtoBuffer) {.gcsafe.} = - pb.write(initProtoField(1, msg.fromPeer.getBytes())) - pb.write(initProtoField(2, msg.data)) - pb.write(initProtoField(3, msg.seqno)) - - for t in msg.topicIDs: - pb.write(initProtoField(4, t)) - - if msg.signature.len > 0: - pb.write(initProtoField(5, msg.signature)) - - if msg.key.len > 0: - pb.write(initProtoField(6, msg.key)) +proc write*(pb: var ProtoBuffer, field: int, subs: SubOpts) = + var ipb = initProtoBuffer() + ipb.write(1, uint64(subs.subscribe)) + ipb.write(2, subs.topic) + ipb.finish() + pb.write(field, ipb) +proc encodeMessage*(msg: Message): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, msg.fromPeer) + pb.write(2, msg.data) + pb.write(3, msg.seqno) + for topic in msg.topicIDs: + pb.write(4, topic) + if len(msg.signature) > 0: + pb.write(5, msg.signature) + if len(msg.key) > 0: + pb.write(6, msg.key) pb.finish() + pb.buffer -proc decodeMessages*(pb: var ProtoBuffer): seq[Message] {.gcsafe.} = - # TODO: which of this fields are really optional? 
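Every new write overload above follows the same nested-message pattern: serialize the sub-message into its own ProtoBuffer, finish() it, then emit the whole buffer as a single length-delimited field of the outer message (the IWant and ControlMessage writers additionally skip empty inner buffers). A reduced sketch of that pattern; the field number, proc name and topic value are illustrative only:

  import libp2p/protobuf/minprotobuf

  proc writeTopicMsg(outer: var ProtoBuffer, field: int, topicId: string) =
    # inner message with a single string field, as in the ControlGraft writer
    var inner = initProtoBuffer()
    inner.write(1, topicId)
    inner.finish()
    # the finished inner buffer becomes one length-delimited outer field
    outer.write(field, inner)

  var pb = initProtoBuffer()
  pb.writeTopicMsg(3, "foobar")
  pb.finish()
  doAssert pb.buffer.len > 0
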
- while true: - var msg: Message - var fromPeer: seq[byte] - if pb.getBytes(1, fromPeer) < 0: - break - try: - msg.fromPeer = PeerID.init(fromPeer).tryGet() - except CatchableError as err: - debug "Invalid fromPeer in message", msg = err.msg - break +proc write*(pb: var ProtoBuffer, field: int, msg: Message) = + pb.write(field, encodeMessage(msg)) - trace "read message field", fromPeer = msg.fromPeer.pretty +proc decodeGraft*(pb: ProtoBuffer): ControlGraft {.inline.} = + trace "decodeGraft: decoding message" + var control = ControlGraft() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodeGraft: read topicId", topic_id = topicId + else: + trace "decodeGraft: topicId is missing" + control - if pb.getBytes(2, msg.data) < 0: - break - trace "read message field", data = msg.data.shortLog +proc decodePrune*(pb: ProtoBuffer): ControlPrune {.inline.} = + trace "decodePrune: decoding message" + var control = ControlPrune() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodePrune: read topicId", topic_id = topicId + else: + trace "decodePrune: topicId is missing" + control - if pb.getBytes(3, msg.seqno) < 0: - break - trace "read message field", seqno = msg.seqno.shortLog +proc decodeIHave*(pb: ProtoBuffer): ControlIHave {.inline.} = + trace "decodeIHave: decoding message" + var control = ControlIHave() + var topicId: string + if pb.getField(1, topicId): + control.topicId = topicId + trace "decodeIHave: read topicId", topic_id = topicId + else: + trace "decodeIHave: topicId is missing" + if pb.getRepeatedField(2, control.messageIDs): + trace "decodeIHave: read messageIDs", message_ids = control.messageIDs + else: + trace "decodeIHave: no messageIDs" + control - var topic: string - while true: - if pb.getString(4, topic) < 0: - break - msg.topicIDs.add(topic) - trace "read message field", topicName = topic - topic = "" +proc decodeIWant*(pb: ProtoBuffer): ControlIWant {.inline.} = + trace "decodeIWant: decoding message" + var control = ControlIWant() + if pb.getRepeatedField(1, control.messageIDs): + trace "decodeIWant: read messageIDs", message_ids = control.messageIDs + else: + trace "decodeIWant: no messageIDs" - discard pb.getBytes(5, msg.signature) - trace "read message field", signature = msg.signature.shortLog +proc decodeControl*(pb: ProtoBuffer): Option[ControlMessage] {.inline.} = + trace "decodeControl: decoding message" + var buffer: seq[byte] + if pb.getField(3, buffer): + var control: ControlMessage + var cpb = initProtoBuffer(buffer) + var ihavepbs: seq[seq[byte]] + var iwantpbs: seq[seq[byte]] + var graftpbs: seq[seq[byte]] + var prunepbs: seq[seq[byte]] - discard pb.getBytes(6, msg.key) - trace "read message field", key = msg.key.shortLog + discard cpb.getRepeatedField(1, ihavepbs) + discard cpb.getRepeatedField(2, iwantpbs) + discard cpb.getRepeatedField(3, graftpbs) + discard cpb.getRepeatedField(4, prunepbs) - result.add(msg) + for item in ihavepbs: + control.ihave.add(decodeIHave(initProtoBuffer(item))) + for item in iwantpbs: + control.iwant.add(decodeIWant(initProtoBuffer(item))) + for item in graftpbs: + control.graft.add(decodeGraft(initProtoBuffer(item))) + for item in prunepbs: + control.prune.add(decodePrune(initProtoBuffer(item))) -proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer {.gcsafe.} = - result = initProtoBuffer() - trace "encoding msg: ", msg = msg.shortLog + trace "decodeControl: " + some(control) + else: + none[ControlMessage]() - if msg.subscriptions.len > 0: - for s in 
msg.subscriptions: - var subs = initProtoBuffer() - encodeSubs(s, subs) +proc decodeSubscription*(pb: ProtoBuffer): SubOpts {.inline.} = + trace "decodeSubscription: decoding message" + var subflag: uint64 + var sub = SubOpts() + if pb.getField(1, subflag): + sub.subscribe = bool(subflag) + trace "decodeSubscription: read subscribe", subscribe = subflag + else: + trace "decodeSubscription: subscribe is missing" + if pb.getField(2, sub.topic): + trace "decodeSubscription: read topic", topic = sub.topic + else: + trace "decodeSubscription: topic is missing" - # write subscriptions to protobuf - if subs.buffer.len > 0: - subs.finish() - result.write(initProtoField(1, subs)) + sub - if msg.messages.len > 0: - var messages = initProtoBuffer() - for m in msg.messages: - encodeMessage(m, messages) +proc decodeSubscriptions*(pb: ProtoBuffer): seq[SubOpts] {.inline.} = + trace "decodeSubscriptions: decoding message" + var subpbs: seq[seq[byte]] + var subs: seq[SubOpts] + if pb.getRepeatedField(1, subpbs): + trace "decodeSubscriptions: read subscriptions", count = len(subpbs) + for item in subpbs: + let sub = decodeSubscription(initProtoBuffer(item)) + subs.add(sub) - # write messages to protobuf - if messages.buffer.len > 0: - messages.finish() - result.write(initProtoField(2, messages)) + if len(subs) == 0: + trace "decodeSubscription: no subscriptions found" - if msg.control.isSome: - var control = initProtoBuffer() - msg.control.get.encodeControl(control) + subs - # write messages to protobuf - if control.buffer.len > 0: - control.finish() - result.write(initProtoField(3, control)) +proc decodeMessage*(pb: ProtoBuffer): Message {.inline.} = + trace "decodeMessage: decoding message" + var msg: Message + if pb.getField(1, msg.fromPeer): + trace "decodeMessage: read fromPeer", fromPeer = msg.fromPeer.pretty() + else: + trace "decodeMessage: fromPeer is missing" - if result.buffer.len > 0: - result.finish() + if pb.getField(2, msg.data): + trace "decodeMessage: read data", data = msg.data.shortLog() + else: + trace "decodeMessage: data is missing" -proc decodeRpcMsg*(msg: seq[byte]): RPCMsg {.gcsafe.} = + if pb.getField(3, msg.seqno): + trace "decodeMessage: read seqno", seqno = msg.data.shortLog() + else: + trace "decodeMessage: seqno is missing" + + if pb.getRepeatedField(4, msg.topicIDs): + trace "decodeMessage: read topics", topic_ids = msg.topicIDs + else: + trace "decodeMessage: topics are missing" + + if pb.getField(5, msg.signature): + trace "decodeMessage: read signature", signature = msg.signature.shortLog() + else: + trace "decodeMessage: signature is missing" + + if pb.getField(6, msg.key): + trace "decodeMessage: read public key", key = msg.key.shortLog() + else: + trace "decodeMessage: public key is missing" + + msg + +proc decodeMessages*(pb: ProtoBuffer): seq[Message] {.inline.} = + trace "decodeMessages: decoding message" + var msgpbs: seq[seq[byte]] + var msgs: seq[Message] + if pb.getRepeatedField(2, msgpbs): + trace "decodeMessages: read messages", count = len(msgpbs) + for item in msgpbs: + let msg = decodeMessage(initProtoBuffer(item)) + msgs.add(msg) + + if len(msgs) == 0: + trace "decodeMessages: no messages found" + + msgs + +proc encodeRpcMsg*(msg: RPCMsg): ProtoBuffer = + trace "encodeRpcMsg: encoding message", msg = msg.shortLog() + var pb = initProtoBuffer() + for item in msg.subscriptions: + pb.write(1, item) + for item in msg.messages: + pb.write(2, item) + if msg.control.isSome(): + pb.write(3, msg.control.get()) + if len(pb.buffer) > 0: + pb.finish() + result = pb + 
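The new encodeRpcMsg and the decodeRpcMsg that follows are symmetric, so an RPC message can be round-tripped straight through the buffer. A small sketch, assuming the RPCMsg/SubOpts definitions from messages.nim and the module paths shown in the diff headers:

  import libp2p/protocols/pubsub/rpc/[messages, protobuf]

  let rpc = RPCMsg(subscriptions: @[SubOpts(subscribe: true, topic: "foobar")])

  let encoded = encodeRpcMsg(rpc)             # ProtoBuffer
  let decoded = decodeRpcMsg(encoded.buffer)  # seq[byte] -> RPCMsg

  doAssert decoded.subscriptions.len == 1
  doAssert decoded.subscriptions[0].subscribe
  doAssert decoded.subscriptions[0].topic == "foobar"
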
+proc decodeRpcMsg*(msg: seq[byte]): RPCMsg = + trace "decodeRpcMsg: decoding message", msg = msg.shortLog() var pb = initProtoBuffer(msg) + var rpcMsg: RPCMsg + rpcMsg.messages = pb.decodeMessages() + rpcMsg.subscriptions = pb.decodeSubscriptions() + rpcMsg.control = pb.decodeControl() - while true: - # decode SubOpts array - var field = pb.enterSubMessage() - trace "processing submessage", field = field - case field: - of 0: - trace "no submessage found in RPC msg" - break - of 1: - result.subscriptions &= pb.decodeSubs() - of 2: - result.messages &= pb.decodeMessages() - of 3: - result.control = pb.decodeControl() - else: - raise newException(CatchableError, "message type not recognized") + rpcMsg diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index fa14674bb..d5398ccbb 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -430,8 +430,8 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon var libp2pProof = initProtoBuffer() - libp2pProof.write(initProtoField(1, p.localPublicKey)) - libp2pProof.write(initProtoField(2, signedPayload.getBytes())) + libp2pProof.write(1, p.localPublicKey) + libp2pProof.write(2, signedPayload.getBytes()) # data field also there but not used! libp2pProof.finish() @@ -449,9 +449,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon remoteSig: Signature remoteSigBytes: seq[byte] - if remoteProof.getLengthValue(1, remotePubKeyBytes) <= 0: + if not(remoteProof.getField(1, remotePubKeyBytes)): raise newException(NoiseHandshakeError, "Failed to deserialize remote public key bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") - if remoteProof.getLengthValue(2, remoteSigBytes) <= 0: + if not(remoteProof.getField(2, remoteSigBytes)): raise newException(NoiseHandshakeError, "Failed to deserialize remote signature bytes. (initiator: " & $initiator & ", peer: " & $conn.peerInfo.peerId & ")") if not remotePubKey.init(remotePubKeyBytes): diff --git a/tests/testminprotobuf.nim b/tests/testminprotobuf.nim new file mode 100644 index 000000000..a4fe7fead --- /dev/null +++ b/tests/testminprotobuf.nim @@ -0,0 +1,677 @@ +## Nim-Libp2p +## Copyright (c) 2018 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. 
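For reading the hex vectors in this new test module: every protobuf field starts with a key byte equal to (field_number shl 3) or wire_type, which is why field 1 appears as 08 (varint), 09 (fixed64), 0a (length-delimited) and 0d (fixed32) in the constants below. The arithmetic, for reference:

  # wire types: 0 = varint, 1 = 64-bit, 2 = length-delimited, 5 = 32-bit
  doAssert ((1 shl 3) or 0) == 0x08
  doAssert ((1 shl 3) or 1) == 0x09
  doAssert ((1 shl 3) or 2) == 0x0a
  doAssert ((1 shl 3) or 5) == 0x0d
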
+ +import unittest +import ../libp2p/protobuf/minprotobuf +import stew/byteutils, strutils + +when defined(nimHasUsed): {.used.} + +suite "MinProtobuf test suite": + const VarintVectors = [ + "0800", "0801", "08ffffffff07", "08ffffffff0f", "08ffffffffffffffff7f", + "08ffffffffffffffffff01" + ] + + const VarintValues = [ + 0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xFFFF_FFFF'u64, + 0x7FFF_FFFF_FFFF_FFFF'u64, 0xFFFF_FFFF_FFFF_FFFF'u64 + ] + + const Fixed32Vectors = [ + "0d00000000", "0d01000000", "0dffffff7f", "0dddccbbaa", "0dffffffff" + ] + + const Fixed32Values = [ + 0x0'u32, 0x1'u32, 0x7FFF_FFFF'u32, 0xAABB_CCDD'u32, 0xFFFF_FFFF'u32 + ] + + const Fixed64Vectors = [ + "090000000000000000", "090100000000000000", "09ffffff7f00000000", + "09ddccbbaa00000000", "09ffffffff00000000", "09ffffffffffffff7f", + "099988ffeeddccbbaa", "09ffffffffffffffff" + ] + + const Fixed64Values = [ + 0x0'u64, 0x1'u64, 0x7FFF_FFFF'u64, 0xAABB_CCDD'u64, 0xFFFF_FFFF'u64, + 0x7FFF_FFFF_FFFF_FFFF'u64, 0xAABB_CCDD_EEFF_8899'u64, + 0xFFFF_FFFF_FFFF_FFFF'u64 + ] + + const LengthVectors = [ + "0a00", "0a0161", "0a026162", "0a0461626364", "0a086162636465666768" + ] + + const LengthValues = [ + "", "a", "ab", "abcd", "abcdefgh" + ] + + ## This vector values was tested with `protoc` and related proto file. + + ## syntax = "proto2"; + ## message testmsg { + ## repeated uint64 d = 1 [packed=true]; + ## repeated uint64 d = 2 [packed=true]; + ## } + const PackedVarintVector = + "0a1f0001ffffffff07ffffffff0fffffffffffffffff7fffffffffffffffffff0112020001" + ## syntax = "proto2"; + ## message testmsg { + ## repeated sfixed32 d = 1 [packed=true]; + ## repeated sfixed32 d = 2 [packed=true]; + ## } + const PackedFixed32Vector = + "0a140000000001000000ffffff7fddccbbaaffffffff12080000000001000000" + ## syntax = "proto2"; + ## message testmsg { + ## repeated sfixed64 d = 1 [packed=true]; + ## repeated sfixed64 d = 2 [packed=true]; + ## } + const PackedFixed64Vector = + """0a4000000000000000000100000000000000ffffff7f00000000ddccbbaa00000000 + ffffffff00000000ffffffffffffff7f9988ffeeddccbbaaffffffffffffffff1210 + 00000000000000000100000000000000""" + + proc getVarintEncodedValue(value: uint64): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getVarintDecodedValue(data: openarray[byte]): uint64 = + var value: uint64 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + value + + proc getFixed32EncodedValue(value: float32): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getFixed32DecodedValue(data: openarray[byte]): uint32 = + var value: float32 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + cast[uint32](value) + + proc getFixed64EncodedValue(value: float64): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getFixed64DecodedValue(data: openarray[byte]): uint64 = + var value: float64 + var pb = initProtoBuffer(data) + let res = pb.getField(1, value) + doAssert(res) + cast[uint64](value) + + proc getLengthEncodedValue(value: string): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getLengthEncodedValue(value: seq[byte]): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, value) + pb.finish() + return pb.buffer + + proc getLengthDecodedValue(data: openarray[byte]): string = + var value = newString(len(data)) + var valueLen = 0 + var pb = 
initProtoBuffer(data) + let res = pb.getField(1, value, valueLen) + + doAssert(res) + value.setLen(valueLen) + value + + proc isFullZero[T: byte|char](data: openarray[T]): bool = + for ch in data: + if int(ch) != 0: + return false + return true + + proc corruptHeader(data: var openarray[byte], index: int) = + var values = [3, 4, 6] + data[0] = data[0] and 0xF8'u8 + data[0] = data[0] or byte(values[index mod len(values)]) + + test "[varint] edge values test": + for i in 0 ..< len(VarintValues): + let data = getVarintEncodedValue(VarintValues[i]) + check: + toHex(data) == VarintVectors[i] + getVarintDecodedValue(data) == VarintValues[i] + + test "[varint] mixing many values with same field number test": + for i in 0 ..< len(VarintValues): + var pb = initProtoBuffer() + for k in 0 ..< len(VarintValues): + let index = (i + k + 1) mod len(VarintValues) + pb.write(1, VarintValues[index]) + pb.finish() + check getVarintDecodedValue(pb.buffer) == VarintValues[i] + + test "[varint] incorrect values test": + for i in 0 ..< len(VarintValues): + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[varint] non-existent field test": + for i in 0 ..< len(VarintValues): + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == 0'u64 + + test "[varint] corrupted header test": + for i in 0 ..< len(VarintValues): + for k in 0 ..< 3: + var value: uint64 + var data = getVarintEncodedValue(VarintValues[i]) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[varint] empty buffer test": + var value: uint64 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == 0'u64 + + test "[varint] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, VarintValues[1]) + pb1.write(1, VarintValues[2]) + pb1.write(2, VarintValues[3]) + pb1.write(1, VarintValues[4]) + pb1.write(1, VarintValues[5]) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[uint64] + var fieldarr2: seq[uint64] + var fieldarr3: seq[uint64] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + fieldarr1[0] == VarintValues[1] + fieldarr1[1] == VarintValues[2] + fieldarr1[2] == VarintValues[4] + fieldarr1[3] == VarintValues[5] + fieldarr2[0] == VarintValues[3] + + test "[varint] Repeated packed field test": + var pb1 = initProtoBuffer() + pb1.writePacked(1, VarintValues) + pb1.writePacked(2, VarintValues[0 .. 
1]) + pb1.finish() + check: + toHex(pb1.buffer) == PackedVarintVector + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[uint64] + var fieldarr2: seq[uint64] + var fieldarr3: seq[uint64] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 6 + fieldarr1[0] == VarintValues[0] + fieldarr1[1] == VarintValues[1] + fieldarr1[2] == VarintValues[2] + fieldarr1[3] == VarintValues[3] + fieldarr1[4] == VarintValues[4] + fieldarr1[5] == VarintValues[5] + fieldarr2[0] == VarintValues[0] + fieldarr2[1] == VarintValues[1] + + test "[fixed32] edge values test": + for i in 0 ..< len(Fixed32Values): + let data = getFixed32EncodedValue(cast[float32](Fixed32Values[i])) + check: + toHex(data) == Fixed32Vectors[i] + getFixed32DecodedValue(data) == Fixed32Values[i] + + test "[fixed32] mixing many values with same field number test": + for i in 0 ..< len(Fixed32Values): + var pb = initProtoBuffer() + for k in 0 ..< len(Fixed32Values): + let index = (i + k + 1) mod len(Fixed32Values) + pb.write(1, cast[float32](Fixed32Values[index])) + pb.finish() + check getFixed32DecodedValue(pb.buffer) == Fixed32Values[i] + + test "[fixed32] incorrect values test": + for i in 0 ..< len(Fixed32Values): + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed32] non-existent field test": + for i in 0 ..< len(Fixed32Values): + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == float32(0) + + test "[fixed32] corrupted header test": + for i in 0 ..< len(Fixed32Values): + for k in 0 ..< 3: + var value: float32 + var data = getFixed32EncodedValue(float32(Fixed32Values[i])) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed32] empty buffer test": + var value: float32 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == float32(0) + + test "[fixed32] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, cast[float32](Fixed32Values[0])) + pb1.write(1, cast[float32](Fixed32Values[1])) + pb1.write(2, cast[float32](Fixed32Values[2])) + pb1.write(1, cast[float32](Fixed32Values[3])) + pb1.write(1, cast[float32](Fixed32Values[4])) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float32] + var fieldarr2: seq[float32] + var fieldarr3: seq[float32] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[uint32](fieldarr1[0]) == Fixed64Values[0] + cast[uint32](fieldarr1[1]) == Fixed64Values[1] + cast[uint32](fieldarr1[2]) == Fixed64Values[3] + cast[uint32](fieldarr1[3]) == Fixed64Values[4] + cast[uint32](fieldarr2[0]) == Fixed64Values[2] + + test "[fixed32] Repeated packed field test": + var pb1 = initProtoBuffer() + var values = newSeq[float32](len(Fixed32Values)) + for i in 0 ..< len(values): + values[i] = cast[float32](Fixed32Values[i]) + pb1.writePacked(1, values) + pb1.writePacked(2, values[0 .. 
1]) + pb1.finish() + check: + toHex(pb1.buffer) == PackedFixed32Vector + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float32] + var fieldarr2: seq[float32] + var fieldarr3: seq[float32] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 5 + cast[uint32](fieldarr1[0]) == Fixed32Values[0] + cast[uint32](fieldarr1[1]) == Fixed32Values[1] + cast[uint32](fieldarr1[2]) == Fixed32Values[2] + cast[uint32](fieldarr1[3]) == Fixed32Values[3] + cast[uint32](fieldarr1[4]) == Fixed32Values[4] + cast[uint32](fieldarr2[0]) == Fixed32Values[0] + cast[uint32](fieldarr2[1]) == Fixed32Values[1] + + test "[fixed64] edge values test": + for i in 0 ..< len(Fixed64Values): + let data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + check: + toHex(data) == Fixed64Vectors[i] + getFixed64DecodedValue(data) == Fixed64Values[i] + + test "[fixed64] mixing many values with same field number test": + for i in 0 ..< len(Fixed64Values): + var pb = initProtoBuffer() + for k in 0 ..< len(Fixed64Values): + let index = (i + k + 1) mod len(Fixed64Values) + pb.write(1, cast[float64](Fixed64Values[index])) + pb.finish() + check getFixed64DecodedValue(pb.buffer) == Fixed64Values[i] + + test "[fixed64] incorrect values test": + for i in 0 ..< len(Fixed64Values): + var value: float32 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed64] non-existent field test": + for i in 0 ..< len(Fixed64Values): + var value: float64 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value) == false + value == float64(0) + + test "[fixed64] corrupted header test": + for i in 0 ..< len(Fixed64Values): + for k in 0 ..< 3: + var value: float64 + var data = getFixed64EncodedValue(cast[float64](Fixed64Values[i])) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value) == false + + test "[fixed64] empty buffer test": + var value: float64 + var pb = initProtoBuffer() + check: + pb.getField(1, value) == false + value == float64(0) + + test "[fixed64] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, cast[float64](Fixed64Values[2])) + pb1.write(1, cast[float64](Fixed64Values[3])) + pb1.write(2, cast[float64](Fixed64Values[4])) + pb1.write(1, cast[float64](Fixed64Values[5])) + pb1.write(1, cast[float64](Fixed64Values[6])) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float64] + var fieldarr2: seq[float64] + var fieldarr3: seq[float64] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[uint64](fieldarr1[0]) == Fixed64Values[2] + cast[uint64](fieldarr1[1]) == Fixed64Values[3] + cast[uint64](fieldarr1[2]) == Fixed64Values[5] + cast[uint64](fieldarr1[3]) == Fixed64Values[6] + cast[uint64](fieldarr2[0]) == Fixed64Values[4] + + test "[fixed64] Repeated packed field test": + var pb1 = initProtoBuffer() + var values = newSeq[float64](len(Fixed64Values)) + for i in 0 ..< len(values): + values[i] = 
cast[float64](Fixed64Values[i]) + pb1.writePacked(1, values) + pb1.writePacked(2, values[0 .. 1]) + pb1.finish() + let expect = PackedFixed64Vector.multiReplace(("\n", ""), (" ", "")) + check: + toHex(pb1.buffer) == expect + + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[float64] + var fieldarr2: seq[float64] + var fieldarr3: seq[float64] + let r1 = pb2.getPackedRepeatedField(1, fieldarr1) + let r2 = pb2.getPackedRepeatedField(2, fieldarr2) + let r3 = pb2.getPackedRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 2 + len(fieldarr1) == 8 + cast[uint64](fieldarr1[0]) == Fixed64Values[0] + cast[uint64](fieldarr1[1]) == Fixed64Values[1] + cast[uint64](fieldarr1[2]) == Fixed64Values[2] + cast[uint64](fieldarr1[3]) == Fixed64Values[3] + cast[uint64](fieldarr1[4]) == Fixed64Values[4] + cast[uint64](fieldarr1[5]) == Fixed64Values[5] + cast[uint64](fieldarr1[6]) == Fixed64Values[6] + cast[uint64](fieldarr1[7]) == Fixed64Values[7] + cast[uint64](fieldarr2[0]) == Fixed64Values[0] + cast[uint64](fieldarr2[1]) == Fixed64Values[1] + + test "[length] edge values test": + for i in 0 ..< len(LengthValues): + let data1 = getLengthEncodedValue(LengthValues[i]) + let data2 = getLengthEncodedValue(cast[seq[byte]](LengthValues[i])) + check: + toHex(data1) == LengthVectors[i] + toHex(data2) == LengthVectors[i] + check: + getLengthDecodedValue(data1) == LengthValues[i] + getLengthDecodedValue(data2) == LengthValues[i] + + test "[length] mixing many values with same field number test": + for i in 0 ..< len(LengthValues): + var pb1 = initProtoBuffer() + var pb2 = initProtoBuffer() + for k in 0 ..< len(LengthValues): + let index = (i + k + 1) mod len(LengthValues) + pb1.write(1, LengthValues[index]) + pb2.write(1, cast[seq[byte]](LengthValues[index])) + pb1.finish() + pb2.finish() + check getLengthDecodedValue(pb1.buffer) == LengthValues[i] + check getLengthDecodedValue(pb2.buffer) == LengthValues[i] + + test "[length] incorrect values test": + for i in 0 ..< len(LengthValues): + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + # corrupting + data.setLen(len(data) - 1) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + + test "[length] non-existent field test": + for i in 0 ..< len(LengthValues): + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + var pb = initProtoBuffer(data) + check: + pb.getField(2, value, valueLen) == false + valueLen == 0 + + test "[length] corrupted header test": + for i in 0 ..< len(LengthValues): + for k in 0 ..< 3: + var value = newSeq[byte](len(LengthValues[i])) + var valueLen = 0 + var data = getLengthEncodedValue(LengthValues[i]) + data.corruptHeader(k) + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + + test "[length] empty buffer test": + var value = newSeq[byte](len(LengthValues[0])) + var valueLen = 0 + var pb = initProtoBuffer() + check: + pb.getField(1, value, valueLen) == false + valueLen == 0 + + test "[length] buffer overflow test": + for i in 1 ..< len(LengthValues): + let data = getLengthEncodedValue(LengthValues[i]) + + var value = newString(len(LengthValues[i]) - 1) + var valueLen = 0 + var pb = initProtoBuffer(data) + check: + pb.getField(1, value, valueLen) == false + valueLen == len(LengthValues[i]) + isFullZero(value) == true + + test "[length] mix of buffer overflow and normal 
fields test": + var pb1 = initProtoBuffer() + pb1.write(1, "TEST10") + pb1.write(1, "TEST20") + pb1.write(1, "TEST") + pb1.write(1, "TEST30") + pb1.write(1, "SOME") + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var value = newString(4) + var valueLen = 0 + check: + pb2.getField(1, value, valueLen) == true + value == "SOME" + + test "[length] too big message test": + var pb1 = initProtoBuffer() + var bigString = newString(MaxMessageSize + 1) + + for i in 0 ..< len(bigString): + bigString[i] = 'A' + pb1.write(1, bigString) + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var value = newString(MaxMessageSize + 1) + var valueLen = 0 + check: + pb2.getField(1, value, valueLen) == false + + test "[length] Repeated field test": + var pb1 = initProtoBuffer() + pb1.write(1, "TEST1") + pb1.write(1, "TEST2") + pb1.write(2, "TEST5") + pb1.write(1, "TEST3") + pb1.write(1, "TEST4") + pb1.finish() + var pb2 = initProtoBuffer(pb1.buffer) + var fieldarr1: seq[seq[byte]] + var fieldarr2: seq[seq[byte]] + var fieldarr3: seq[seq[byte]] + let r1 = pb2.getRepeatedField(1, fieldarr1) + let r2 = pb2.getRepeatedField(2, fieldarr2) + let r3 = pb2.getRepeatedField(3, fieldarr3) + check: + r1 == true + r2 == true + r3 == false + len(fieldarr3) == 0 + len(fieldarr2) == 1 + len(fieldarr1) == 4 + cast[string](fieldarr1[0]) == "TEST1" + cast[string](fieldarr1[1]) == "TEST2" + cast[string](fieldarr1[2]) == "TEST3" + cast[string](fieldarr1[3]) == "TEST4" + cast[string](fieldarr2[0]) == "TEST5" + + test "Different value types in one message with same field number test": + proc getEncodedValue(): seq[byte] = + var pb = initProtoBuffer() + pb.write(1, VarintValues[1]) + pb.write(2, cast[float32](Fixed32Values[1])) + pb.write(3, cast[float64](Fixed64Values[1])) + pb.write(4, LengthValues[1]) + + pb.write(1, VarintValues[2]) + pb.write(2, cast[float32](Fixed32Values[2])) + pb.write(3, cast[float64](Fixed64Values[2])) + pb.write(4, LengthValues[2]) + + pb.write(1, cast[float32](Fixed32Values[3])) + pb.write(2, cast[float64](Fixed64Values[3])) + pb.write(3, LengthValues[3]) + pb.write(4, VarintValues[3]) + + pb.write(1, cast[float64](Fixed64Values[4])) + pb.write(2, LengthValues[4]) + pb.write(3, VarintValues[4]) + pb.write(4, cast[float32](Fixed32Values[4])) + + pb.write(1, VarintValues[1]) + pb.write(2, cast[float32](Fixed32Values[1])) + pb.write(3, cast[float64](Fixed64Values[1])) + pb.write(4, LengthValues[1]) + pb.finish() + pb.buffer + + let msg = getEncodedValue() + let pb = initProtoBuffer(msg) + var varintValue: uint64 + var fixed32Value: float32 + var fixed64Value: float64 + var lengthValue = newString(10) + var lengthSize: int + + check: + pb.getField(1, varintValue) == true + pb.getField(2, fixed32Value) == true + pb.getField(3, fixed64Value) == true + pb.getField(4, lengthValue, lengthSize) == true + + lengthValue.setLen(lengthSize) + + check: + varintValue == VarintValues[1] + cast[uint32](fixed32Value) == Fixed32Values[1] + cast[uint64](fixed64Value) == Fixed64Values[1] + lengthValue == LengthValues[1] diff --git a/tests/testnative.nim b/tests/testnative.nim index 53e4a244e..9f75ffec6 100644 --- a/tests/testnative.nim +++ b/tests/testnative.nim @@ -1,4 +1,5 @@ import testvarint, + testminprotobuf, teststreamseq import testrsa, From fcda0f6ce1cc2573e9129dbd7b37ccde8e657008 Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni Date: Mon, 13 Jul 2020 22:32:38 +0900 Subject: [PATCH 15/23] PubSubPeer tables refactor (#263) * refactor peer tables * tests fixing * override PubSubPeer equality * fix 
pubsubpeer comparison --- libp2p/protocols/pubsub/floodsub.nim | 23 +++-- libp2p/protocols/pubsub/gossipsub.nim | 114 +++++++++++-------------- libp2p/protocols/pubsub/pubsub.nim | 24 ++++-- libp2p/protocols/pubsub/pubsubpeer.nim | 29 +++++++ tests/pubsub/testfloodsub.nim | 2 +- tests/pubsub/testgossipinternal.nim | 110 ++++++++++++------------ tests/pubsub/testgossipsub.nim | 16 ++-- 7 files changed, 175 insertions(+), 143 deletions(-) diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index c83c2664d..cf43b70f4 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -26,7 +26,7 @@ const FloodSubCodec* = "/floodsub/1.0.0" type FloodSub* = ref object of PubSub - floodsub*: Table[string, HashSet[string]] # topic to remote peer map + floodsub*: PeerTable # topic to remote peer map seen*: TimedCache[string] # list of messages forwarded to peers method subscribeTopic*(f: FloodSub, @@ -35,23 +35,28 @@ method subscribeTopic*(f: FloodSub, peerId: string) {.gcsafe, async.} = await procCall PubSub(f).subscribeTopic(topic, subscribe, peerId) + let peer = f.peers.getOrDefault(peerId) + if peer == nil: + debug "subscribeTopic on a nil peer!" + return + if topic notin f.floodsub: - f.floodsub[topic] = initHashSet[string]() + f.floodsub[topic] = initHashSet[PubSubPeer]() if subscribe: - trace "adding subscription for topic", peer = peerId, name = topic + trace "adding subscription for topic", peer = peer.id, name = topic # subscribe the peer to the topic - f.floodsub[topic].incl(peerId) + f.floodsub[topic].incl(peer) else: - trace "removing subscription for topic", peer = peerId, name = topic + trace "removing subscription for topic", peer = peer.id, name = topic # unsubscribe the peer from the topic - f.floodsub[topic].excl(peerId) + f.floodsub[topic].excl(peer) method handleDisconnect*(f: FloodSub, peer: PubSubPeer) = ## handle peer disconnects for t in toSeq(f.floodsub.keys): if t in f.floodsub: - f.floodsub[t].excl(peer.id) + f.floodsub[t].excl(peer) procCall PubSub(f).handleDisconnect(peer) @@ -62,7 +67,7 @@ method rpcHandler*(f: FloodSub, for m in rpcMsgs: # for all RPC messages if m.messages.len > 0: # if there are any messages - var toSendPeers: HashSet[string] = initHashSet[string]() + var toSendPeers = initHashSet[PubSubPeer]() for msg in m.messages: # for every message let msgId = f.msgIdProvider(msg) logScope: msgId @@ -158,6 +163,6 @@ method initPubSub*(f: FloodSub) = procCall PubSub(f).initPubSub() f.peers = initTable[string, PubSubPeer]() f.topics = initTable[string, Topic]() - f.floodsub = initTable[string, HashSet[string]]() + f.floodsub = initTable[string, HashSet[PubSubPeer]]() f.seen = newTimedCache[string](2.minutes) f.init() diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 32240bc75..0ec1795a6 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -45,9 +45,9 @@ const GossipSubFanoutTTL* = 1.minutes type GossipSub* = ref object of FloodSub - mesh*: Table[string, HashSet[string]] # peers that we send messages to when we are subscribed to the topic - fanout*: Table[string, HashSet[string]] # peers that we send messages to when we're not subscribed to the topic - gossipsub*: Table[string, HashSet[string]] # peers that are subscribed to a topic + mesh*: PeerTable # peers that we send messages to when we are subscribed to the topic + fanout*: PeerTable # peers that we send messages to when we're not subscribed to the 
topic + gossipsub*: PeerTable # peers that are subscribed to a topic lastFanoutPubSub*: Table[string, Moment] # last publish time for fanout topics gossip*: Table[string, seq[ControlIHave]] # pending gossip control*: Table[string, ControlMessage] # pending control messages @@ -68,23 +68,20 @@ declareGauge(libp2p_gossipsub_peers_per_topic_gossipsub, "gossipsub peers per topic in gossipsub", labels = ["topic"]) -func addPeer( - table: var Table[string, HashSet[string]], topic: string, - peerId: string): bool = +func addPeer(table: var PeerTable, topic: string, peer: PubSubPeer): bool = # returns true if the peer was added, false if it was already in the collection - not table.mgetOrPut(topic, initHashSet[string]()).containsOrIncl(peerId) + not table.mgetOrPut(topic, initHashSet[PubSubPeer]()).containsOrIncl(peer) -func removePeer( - table: var Table[string, HashSet[string]], topic, peerId: string) = +func removePeer(table: var PeerTable, topic: string, peer: PubSubPeer) = table.withValue(topic, peers): - peers[].excl(peerId) + peers[].excl(peer) if peers[].len == 0: table.del(topic) -func hasPeer(table: Table[string, HashSet[string]], topic, peerId: string): bool = - (topic in table) and (peerId in table[topic]) +func hasPeer(table: PeerTable, topic: string, peer: PubSubPeer): bool = + (topic in table) and (peer in table[topic]) -func peers(table: Table[string, HashSet[string]], topic: string): int = +func peers(table: PeerTable, topic: string): int = if topic in table: table[topic].len else: @@ -112,8 +109,8 @@ proc replenishFanout(g: GossipSub, topic: string) = if g.fanout.peers(topic) < GossipSubDLo: trace "replenishing fanout", peers = g.fanout.peers(topic) if topic in g.gossipsub: - for peerId in g.gossipsub[topic]: - if g.fanout.addPeer(topic, peerId): + for peer in g.gossipsub[topic]: + if g.fanout.addPeer(topic, peer): if g.fanout.peers(topic) == GossipSubD: break @@ -133,8 +130,8 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = trace "replenishing mesh", topic, peers = g.mesh.peers(topic) # replenish the mesh if we're below GossipSubDlo var newPeers = toSeq( - g.gossipsub.getOrDefault(topic, initHashSet[string]()) - - g.mesh.getOrDefault(topic, initHashSet[string]()) + g.gossipsub.getOrDefault(topic, initHashSet[PubSubPeer]()) - + g.mesh.getOrDefault(topic, initHashSet[PubSubPeer]()) ) logScope: @@ -146,19 +143,11 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = trace "getting peers", topic, peers = newPeers.len - for id in newPeers: - if g.mesh.peers(topic) >= GossipSubD: - break - - let p = g.peers.getOrDefault(id) - if p != nil: - # send a graft message to the peer - grafts.add p - discard g.mesh.addPeer(topic, id) - trace "got peer", peer = id - else: - # Peer should have been removed from mesh also! 
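addPeer/removePeer/hasPeer/peers above are the whole contract of the new PeerTable: addPeer reports whether the peer was newly inserted, and removePeer drops a topic entry once its set becomes empty. A standalone restatement over plain string ids (so it runs without PubSubPeer), using only the stdlib:

  import std/[tables, sets]

  type DemoTable = Table[string, HashSet[string]]

  func addPeer(table: var DemoTable, topic, peer: string): bool =
    # true only when the peer was newly inserted for this topic
    not table.mgetOrPut(topic, initHashSet[string]()).containsOrIncl(peer)

  func removePeer(table: var DemoTable, topic, peer: string) =
    table.withValue(topic, peers):
      peers[].excl(peer)
      if peers[].len == 0:
        table.del(topic)          # drop empty topic entries entirely

  var t: DemoTable
  doAssert t.addPeer("foobar", "peerA")       # newly added
  doAssert not t.addPeer("foobar", "peerA")   # already present
  t.removePeer("foobar", "peerA")
  doAssert "foobar" notin t                   # topic gone once empty
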
- warn "Unknown peer in mesh", peer = id + for peer in newPeers: + # send a graft message to the peer + grafts.add peer + discard g.mesh.addPeer(topic, peer) + trace "got peer", peer = peer.id if g.mesh.peers(topic) > GossipSubDhi: # prune peers if we've gone over @@ -166,17 +155,14 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} = shuffle(mesh) trace "about to prune mesh", mesh = mesh.len - for id in mesh: + for peer in mesh: if g.mesh.peers(topic) <= GossipSubD: break trace "pruning peers", peers = g.mesh.peers(topic) # send a graft message to the peer - g.mesh.removePeer(topic, id) - - let p = g.peers.getOrDefault(id) - if p != nil: - prunes.add(p) + g.mesh.removePeer(topic, peer) + prunes.add(peer) libp2p_gossipsub_peers_per_topic_gossipsub .set(g.gossipsub.peers(topic).int64, labelValues = [topic]) @@ -236,18 +222,18 @@ proc getGossipPeers(g: GossipSub): Table[string, ControlMessage] {.gcsafe.} = trace "topic not in gossip array, skipping", topicID = topic continue - for id in allPeers: + for peer in allPeers: if result.len >= GossipSubD: trace "got gossip peers", peers = result.len break - if id in gossipPeers: + if peer in gossipPeers: continue - if id notin result: - result[id] = controlMsg + if peer.id notin result: + result[peer.id] = controlMsg - result[id].ihave.add(ihave) + result[peer.id].ihave.add(ihave) proc heartbeat(g: GossipSub) {.async.} = while g.heartbeatRunning: @@ -282,19 +268,19 @@ method handleDisconnect*(g: GossipSub, peer: PubSubPeer) = ## handle peer disconnects procCall FloodSub(g).handleDisconnect(peer) for t in toSeq(g.gossipsub.keys): - g.gossipsub.removePeer(t, peer.id) + g.gossipsub.removePeer(t, peer) libp2p_gossipsub_peers_per_topic_gossipsub .set(g.gossipsub.peers(t).int64, labelValues = [t]) for t in toSeq(g.mesh.keys): - g.mesh.removePeer(t, peer.id) + g.mesh.removePeer(t, peer) libp2p_gossipsub_peers_per_topic_mesh .set(g.mesh.peers(t).int64, labelValues = [t]) for t in toSeq(g.fanout.keys): - g.fanout.removePeer(t, peer.id) + g.fanout.removePeer(t, peer) libp2p_gossipsub_peers_per_topic_fanout .set(g.fanout.peers(t).int64, labelValues = [t]) @@ -310,16 +296,21 @@ method subscribeTopic*(g: GossipSub, peerId: string) {.gcsafe, async.} = await procCall FloodSub(g).subscribeTopic(topic, subscribe, peerId) + let peer = g.peers.getOrDefault(peerId) + if peer == nil: + debug "subscribeTopic on a nil peer!" + return + if subscribe: trace "adding subscription for topic", peer = peerId, name = topic # subscribe remote peer to the topic - discard g.gossipsub.addPeer(topic, peerId) + discard g.gossipsub.addPeer(topic, peer) else: trace "removing subscription for topic", peer = peerId, name = topic # unsubscribe remote peer from the topic - g.gossipsub.removePeer(topic, peerId) - g.mesh.removePeer(topic, peerId) - g.fanout.removePeer(topic, peerId) + g.gossipsub.removePeer(topic, peer) + g.mesh.removePeer(topic, peer) + g.fanout.removePeer(topic, peer) libp2p_gossipsub_peers_per_topic_mesh .set(g.mesh.peers(topic).int64, labelValues = [topic]) @@ -338,10 +329,9 @@ method subscribeTopic*(g: GossipSub, proc handleGraft(g: GossipSub, peer: PubSubPeer, grafts: seq[ControlGraft]): seq[ControlPrune] = - let peerId = peer.id for graft in grafts: let topic = graft.topicID - trace "processing graft message", topic, peerId + trace "processing graft message", topic, peer # If they send us a graft before they send us a subscribe, what should # we do? For now, we add them to mesh but don't add them to gossipsub. 
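The rebalancing rule itself is unchanged by this refactor: top the mesh up from known gossipsub peers when it falls under GossipSubDLo, and prune a random selection when it grows past GossipSubDhi. A standalone sketch of the rule (not the exact control flow above), over plain string ids and assuming the usual gossipsub defaults of D=6, Dlo=4, Dhi=12:

  import std/[sets, sequtils, random]

  const
    D = 6       # assumed defaults for GossipSubD / GossipSubDLo / GossipSubDhi
    Dlo = 4
    Dhi = 12

  proc rebalance(mesh: var HashSet[string], gossipsub: HashSet[string]) =
    if mesh.len < Dlo:
      # top up from known peers that are not yet in the mesh
      for id in gossipsub - mesh:
        if mesh.len >= D: break
        mesh.incl(id)             # the real code also queues a GRAFT here
    if mesh.len > Dhi:
      # prune a random selection back down to D
      var members = toSeq(mesh)
      shuffle(members)
      while mesh.len > D:
        mesh.excl(members.pop())  # the real code also queues a PRUNE here

  var mesh = ["a", "b"].toHashSet
  rebalance(mesh, ["a", "b", "c", "d", "e", "f", "g", "h"].toHashSet)
  doAssert mesh.len == D
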
@@ -351,10 +341,10 @@ proc handleGraft(g: GossipSub, # In the spec, there's no mention of DHi here, but implicitly, a # peer will be removed from the mesh on next rebalance, so we don't want # this peer to push someone else out - if g.mesh.addPeer(topic, peerId): - g.fanout.removePeer(topic, peer.id) + if g.mesh.addPeer(topic, peer): + g.fanout.removePeer(topic, peer) else: - trace "Peer already in mesh", topic, peerId + trace "Peer already in mesh", topic, peer else: result.add(ControlPrune(topicID: topic)) else: @@ -368,7 +358,7 @@ proc handlePrune(g: GossipSub, peer: PubSubPeer, prunes: seq[ControlPrune]) = trace "processing prune message", peer = peer.id, topicID = prune.topicID - g.mesh.removePeer(prune.topicID, peer.id) + g.mesh.removePeer(prune.topicID, peer) libp2p_gossipsub_peers_per_topic_mesh .set(g.mesh.peers(prune.topicID).int64, labelValues = [prune.topicID]) @@ -403,7 +393,7 @@ method rpcHandler*(g: GossipSub, for m in rpcMsgs: # for all RPC messages if m.messages.len > 0: # if there are any messages - var toSendPeers: HashSet[string] + var toSendPeers: HashSet[PubSubPeer] for msg in m.messages: # for every message let msgId = g.msgIdProvider(msg) logScope: msgId @@ -485,10 +475,8 @@ method unsubscribe*(g: GossipSub, let peers = g.mesh.getOrDefault(topic) g.mesh.del(topic) - for id in peers: - let p = g.peers.getOrDefault(id) - if p != nil: - await p.sendPrune(@[topic]) + for peer in peers: + await peer.sendPrune(@[topic]) method publish*(g: GossipSub, topic: string, @@ -497,7 +485,7 @@ method publish*(g: GossipSub, discard await procCall PubSub(g).publish(topic, data) trace "about to publish message on topic", name = topic, data = data.shortLog - var peers: HashSet[string] + var peers: HashSet[PubSubPeer] if topic.len <= 0: # data could be 0/empty return 0 @@ -578,9 +566,9 @@ method initPubSub*(g: GossipSub) = randomize() g.mcache = newMCache(GossipSubHistoryGossip, GossipSubHistoryLength) - g.mesh = initTable[string, HashSet[string]]() # meshes - topic to peer - g.fanout = initTable[string, HashSet[string]]() # fanout - topic to peer - g.gossipsub = initTable[string, HashSet[string]]()# topic to peer map of all gossipsub peers + g.mesh = initTable[string, HashSet[PubSubPeer]]() # meshes - topic to peer + g.fanout = initTable[string, HashSet[PubSubPeer]]() # fanout - topic to peer + g.gossipsub = initTable[string, HashSet[PubSubPeer]]()# topic to peer map of all gossipsub peers g.lastFanoutPubSub = initTable[string, Moment]() # last publish time for fanout topics g.gossip = initTable[string, seq[ControlIHave]]() # pending gossip g.control = initTable[string, ControlMessage]() # pending control messages diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 1a97ffab1..daa05d165 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -30,6 +30,8 @@ declareCounter(libp2p_pubsub_validation_failure, "pubsub failed validated messag declarePublicCounter(libp2p_pubsub_messages_published, "published messages", labels = ["topic"]) type + PeerTable* = Table[string, HashSet[PubSubPeer]] + SendRes = tuple[published: seq[string], failed: seq[string]] # keep private TopicHandler* = proc(topic: string, @@ -59,6 +61,16 @@ type observers: ref seq[PubSubObserver] # ref as in smart_ptr msgIdProvider*: MsgIdProvider # Turn message into message id (not nil) +proc hasPeerID*(t: PeerTable, topic, peerId: string): bool = + # unefficient but used only in tests! 
+ let peers = t.getOrDefault(topic) + if peers.len == 0: + false + else: + let ps = toSeq(peers) + ps.any do (peer: PubSubPeer) -> bool: + peer.id == peerId + method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = ## handle peer disconnects ## @@ -243,20 +255,16 @@ method subscribe*(p: PubSub, libp2p_pubsub_topics.inc() proc sendHelper*(p: PubSub, - sendPeers: HashSet[string], + sendPeers: HashSet[PubSubPeer], msgs: seq[Message]): Future[SendRes] {.async.} = var sent: seq[tuple[id: string, fut: Future[void]]] for sendPeer in sendPeers: # avoid sending to self - if sendPeer == p.peerInfo.id: + if sendPeer.peerInfo == p.peerInfo: continue - let peer = p.peers.getOrDefault(sendPeer) - if isNil(peer): - continue - - trace "sending messages to peer", peer = peer.id, msgs - sent.add((id: peer.id, fut: peer.send(@[RPCMsg(messages: msgs)]))) + trace "sending messages to peer", peer = sendPeer.id, msgs + sent.add((id: sendPeer.id, fut: sendPeer.send(@[RPCMsg(messages: msgs)]))) var published: seq[string] var failed: seq[string] diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index db1312e83..f5fcd1719 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -43,6 +43,35 @@ type RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} +func hash*(p: PubSubPeer): Hash = + # int is either 32/64, so intptr basically, pubsubpeer is a ref + cast[pointer](p).hash + +func `==`*(a, b: PubSubPeer): bool = + # override equiality to support both nil and peerInfo comparisons + # this in the future will allow us to recycle refs + let + aptr = cast[pointer](a) + bptr = cast[pointer](b) + if aptr == nil: + if bptr == nil: + true + else: + false + elif bptr == nil: + false + else: + if a.peerInfo == nil: + if b.peerInfo == nil: + true + else: + false + else: + if b.peerInfo == nil: + false + else: + a.peerInfo.id == b.peerInfo.id + proc id*(p: PubSubPeer): string = p.peerInfo.id proc connected*(p: PubSubPeer): bool = diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index f6798976f..199921053 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -29,7 +29,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = var ceil = 15 let fsub = cast[FloodSub](sender.pubSub.get()) while not fsub.floodsub.hasKey(key) or - not fsub.floodsub[key].contains(receiver.peerInfo.id): + not fsub.floodsub.hasPeerID(key, receiver.peerInfo.id): await sleepAsync(100.millis) dec ceil doAssert(ceil > 0, "waitSub timeout!") diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index 2fb13d765..be0127873 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -29,18 +29,19 @@ suite "GossipSub internal": let gossipSub = newPubSub(TestGossipSub, randomPeerInfo()) let topic = "foobar" - gossipSub.mesh[topic] = initHashSet[string]() + gossipSub.mesh[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() for i in 0..<15: let conn = newBufferStream(noop) conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].conn = conn - gossipSub.mesh[topic].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.conn = conn + 
gossipSub.peers[peerInfo.id] = peer + gossipSub.mesh[topic].incl(peer) check gossipSub.peers.len == 15 await gossipSub.rebalanceMesh(topic) @@ -58,19 +59,20 @@ suite "GossipSub internal": let gossipSub = newPubSub(TestGossipSub, randomPeerInfo()) let topic = "foobar" - gossipSub.mesh[topic] = initHashSet[string]() + gossipSub.mesh[topic] = initHashSet[PubSubPeer]() gossipSub.topics[topic] = Topic() # has to be in topics to rebalance - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() for i in 0..<15: let conn = newBufferStream(noop) conns &= conn let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].conn = conn - gossipSub.mesh[topic].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.conn = conn + gossipSub.peers[peerInfo.id] = peer + gossipSub.mesh[topic].incl(peer) check gossipSub.mesh[topic].len == 15 await gossipSub.rebalanceMesh(topic) @@ -91,7 +93,7 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() for i in 0..<15: @@ -99,9 +101,9 @@ suite "GossipSub internal": conns &= conn var peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler - gossipSub.gossipsub[topic].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler + gossipSub.gossipsub[topic].incl(peer) check gossipSub.gossipsub[topic].len == 15 gossipSub.replenishFanout(topic) @@ -122,7 +124,7 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.fanout[topic] = initHashSet[string]() + gossipSub.fanout[topic] = initHashSet[PubSubPeer]() gossipSub.lastFanoutPubSub[topic] = Moment.fromNow(1.millis) await sleepAsync(5.millis) # allow the topic to expire @@ -132,9 +134,9 @@ suite "GossipSub internal": conns &= conn let peerInfo = PeerInfo.init(PrivateKey.random(ECDSA, rng[]).get()) conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler - gossipSub.fanout[topic].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler + gossipSub.fanout[topic].incl(peer) check gossipSub.fanout[topic].len == GossipSubD @@ -157,8 +159,8 @@ suite "GossipSub internal": let topic1 = "foobar1" let topic2 = "foobar2" - gossipSub.fanout[topic1] = initHashSet[string]() - gossipSub.fanout[topic2] = initHashSet[string]() + gossipSub.fanout[topic1] = initHashSet[PubSubPeer]() + gossipSub.fanout[topic2] = initHashSet[PubSubPeer]() gossipSub.lastFanoutPubSub[topic1] = Moment.fromNow(1.millis) gossipSub.lastFanoutPubSub[topic2] = Moment.fromNow(1.minutes) await sleepAsync(5.millis) # allow the topic to expire @@ -169,10 +171,10 @@ suite "GossipSub internal": conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler - gossipSub.fanout[topic1].incl(peerInfo.id) - gossipSub.fanout[topic2].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler + gossipSub.fanout[topic1].incl(peer) + 
gossipSub.fanout[topic2].incl(peer) check gossipSub.fanout[topic1].len == GossipSubD check gossipSub.fanout[topic2].len == GossipSubD @@ -196,9 +198,9 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.mesh[topic] = initHashSet[string]() - gossipSub.fanout[topic] = initHashSet[string]() - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.mesh[topic] = initHashSet[PubSubPeer]() + gossipSub.fanout[topic] = initHashSet[PubSubPeer]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() # generate mesh and fanout peers @@ -207,12 +209,12 @@ suite "GossipSub internal": conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler if i mod 2 == 0: - gossipSub.fanout[topic].incl(peerInfo.id) + gossipSub.fanout[topic].incl(peer) else: - gossipSub.mesh[topic].incl(peerInfo.id) + gossipSub.mesh[topic].incl(peer) # generate gossipsub (free standing) peers for i in 0..<15: @@ -220,9 +222,9 @@ suite "GossipSub internal": conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler - gossipSub.gossipsub[topic].incl(peerInfo.id) + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler + gossipSub.gossipsub[topic].incl(peer) # generate messages for i in 0..5: @@ -240,8 +242,8 @@ suite "GossipSub internal": let peers = gossipSub.getGossipPeers() check peers.len == GossipSubD for p in peers.keys: - check p notin gossipSub.fanout[topic] - check p notin gossipSub.mesh[topic] + check not gossipSub.fanout.hasPeerID(topic, p) + check not gossipSub.mesh.hasPeerID(topic, p) await allFuturesThrowing(conns.mapIt(it.close())) @@ -258,20 +260,20 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.fanout[topic] = initHashSet[string]() - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.fanout[topic] = initHashSet[PubSubPeer]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() for i in 0..<30: let conn = newBufferStream(noop) conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler if i mod 2 == 0: - gossipSub.fanout[topic].incl(peerInfo.id) + gossipSub.fanout[topic].incl(peer) else: - gossipSub.gossipsub[topic].incl(peerInfo.id) + gossipSub.gossipsub[topic].incl(peer) # generate messages for i in 0..5: @@ -300,20 +302,20 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.mesh[topic] = initHashSet[string]() - gossipSub.gossipsub[topic] = initHashSet[string]() + gossipSub.mesh[topic] = initHashSet[PubSubPeer]() + gossipSub.gossipsub[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() for i in 0..<30: let conn = newBufferStream(noop) conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler if i mod 2 == 0: - gossipSub.mesh[topic].incl(peerInfo.id) + 
gossipSub.mesh[topic].incl(peer) else: - gossipSub.gossipsub[topic].incl(peerInfo.id) + gossipSub.gossipsub[topic].incl(peer) # generate messages for i in 0..5: @@ -342,20 +344,20 @@ suite "GossipSub internal": discard let topic = "foobar" - gossipSub.mesh[topic] = initHashSet[string]() - gossipSub.fanout[topic] = initHashSet[string]() + gossipSub.mesh[topic] = initHashSet[PubSubPeer]() + gossipSub.fanout[topic] = initHashSet[PubSubPeer]() var conns = newSeq[Connection]() for i in 0..<30: let conn = newBufferStream(noop) conns &= conn let peerInfo = randomPeerInfo() conn.peerInfo = peerInfo - gossipSub.peers[peerInfo.id] = newPubSubPeer(peerInfo, GossipSubCodec) - gossipSub.peers[peerInfo.id].handler = handler + let peer = newPubSubPeer(peerInfo, GossipSubCodec) + peer.handler = handler if i mod 2 == 0: - gossipSub.mesh[topic].incl(peerInfo.id) + gossipSub.mesh[topic].incl(peer) else: - gossipSub.fanout[topic].incl(peerInfo.id) + gossipSub.fanout[topic].incl(peer) # generate messages for i in 0..5: diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 757429500..3bcfe6cf6 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -32,11 +32,11 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = var ceil = 15 let fsub = GossipSub(sender.pubSub.get()) while (not fsub.gossipsub.hasKey(key) or - not fsub.gossipsub[key].contains(receiver.peerInfo.id)) and + not fsub.gossipsub.hasPeerID(key, receiver.peerInfo.id)) and (not fsub.mesh.hasKey(key) or - not fsub.mesh[key].contains(receiver.peerInfo.id)) and + not fsub.mesh.hasPeerID(key, receiver.peerInfo.id)) and (not fsub.fanout.hasKey(key) or - not fsub.fanout[key].contains(receiver.peerInfo.id)): + not fsub.fanout.hasPeerID(key , receiver.peerInfo.id)): trace "waitSub sleeping..." 
await sleepAsync(1.seconds) dec ceil @@ -192,7 +192,7 @@ suite "GossipSub": check: "foobar" in gossip2.topics "foobar" in gossip1.gossipsub - gossip2.peerInfo.id in gossip1.gossipsub["foobar"] + gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(awaitters) @@ -236,11 +236,11 @@ suite "GossipSub": "foobar" in gossip1.gossipsub "foobar" in gossip2.gossipsub - gossip2.peerInfo.id in gossip1.gossipsub["foobar"] or - gossip2.peerInfo.id in gossip1.mesh["foobar"] + gossip1.gossipsub.hasPeerID("foobar", gossip2.peerInfo.id) or + gossip1.mesh.hasPeerID("foobar", gossip2.peerInfo.id) - gossip1.peerInfo.id in gossip2.gossipsub["foobar"] or - gossip1.peerInfo.id in gossip2.mesh["foobar"] + gossip2.gossipsub.hasPeerID("foobar", gossip1.peerInfo.id) or + gossip2.mesh.hasPeerID("foobar", gossip1.peerInfo.id) await allFuturesThrowing(nodes.mapIt(it.stop())) await allFuturesThrowing(awaitters) From c7895ccc528e119b8f5ba421082267f96683dd93 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 16:15:27 +0200 Subject: [PATCH 16/23] metrics: fix pubsub_peers add metric --- libp2p/protocols/pubsub/pubsub.nim | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index daa05d165..cab34a209 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -144,6 +144,10 @@ proc getOrCreatePeer(p: PubSub, p.peers[peer.id] = peer peer.observers = p.observers + + # metrics + libp2p_pubsub_peers.set(p.peers.len.int64) + return peer method handleConn*(p: PubSub, From 87e58c1c8de467cb60ea01f98d2ec4a6ca22d261 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 16:16:46 +0200 Subject: [PATCH 17/23] metrics: one more pubsub peers fix --- libp2p/protocols/pubsub/pubsub.nim | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index cab34a209..d9679fa85 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -79,8 +79,8 @@ method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.base.} = p.peers.del(peer.id) trace "peer disconnected", peer = peer.id - # metrics - libp2p_pubsub_peers.set(p.peers.len.int64) + # metrics + libp2p_pubsub_peers.set(p.peers.len.int64) proc sendSubs*(p: PubSub, peer: PubSubPeer, From 061c54d3c689809a1f48030d1a5923af00c90224 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 17:26:05 +0200 Subject: [PATCH 18/23] logging fixes --- libp2p/protocols/pubsub/gossipsub.nim | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 0ec1795a6..45bcb8676 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -331,7 +331,7 @@ proc handleGraft(g: GossipSub, grafts: seq[ControlGraft]): seq[ControlPrune] = for graft in grafts: let topic = graft.topicID - trace "processing graft message", topic, peer + trace "processing graft message", topic, peer = peer.id # If they send us a graft before they send us a subscribe, what should # we do? For now, we add them to mesh but don't add them to gossipsub. 
@@ -344,7 +344,7 @@ proc handleGraft(g: GossipSub, if g.mesh.addPeer(topic, peer): g.fanout.removePeer(topic, peer) else: - trace "Peer already in mesh", topic, peer + trace "Peer already in mesh", topic, peer = peer.id else: result.add(ControlPrune(topicID: topic)) else: From 0d4c74b33a4f6c63b368d389b54522aed0f357d4 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 18:36:49 +0200 Subject: [PATCH 19/23] comment log that can't be json-serialized --- libp2p/protocols/pubsub/gossipsub.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 45bcb8676..8bf6b9b44 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -510,7 +510,7 @@ method publish*(g: GossipSub, msg = Message.init(g.peerInfo, data, topic, g.sign) msgId = g.msgIdProvider(msg) - trace "created new message", msg + # trace "created new message", msg trace "publishing on topic", name = topic, peers = peers if msgId notin g.mcache: From 6620b7a00bc29ebeb43c1ebfcd52ee01d4fc2c34 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 19:30:18 +0200 Subject: [PATCH 20/23] more comment fixes --- libp2p/protocols/pubsub/gossipsub.nim | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 8bf6b9b44..75655a6dd 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -510,9 +510,7 @@ method publish*(g: GossipSub, msg = Message.init(g.peerInfo, data, topic, g.sign) msgId = g.msgIdProvider(msg) - # trace "created new message", msg - - trace "publishing on topic", name = topic, peers = peers + trace "publishing on topic", name = topic, peers = peers, msg if msgId notin g.mcache: g.mcache.put(msgId, msg) From 76853f064ac030a06f1855e37c17b3cd18652af0 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 19:59:49 +0200 Subject: [PATCH 21/23] json logging again --- libp2p/protocols/pubsub/gossipsub.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 75655a6dd..db23f422f 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -510,7 +510,7 @@ method publish*(g: GossipSub, msg = Message.init(g.peerInfo, data, topic, g.sign) msgId = g.msgIdProvider(msg) - trace "publishing on topic", name = topic, peers = peers, msg + trace "publishing on topic", name = topic, peers = peers, msg = msg.shortLog() if msgId notin g.mcache: g.mcache.put(msgId, msg) From c6c2d99907b2f86bc67b3e5bddf97ac966581c88 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 13 Jul 2020 20:19:20 +0200 Subject: [PATCH 22/23] one more log fix --- libp2p/protocols/pubsub/gossipsub.nim | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index db23f422f..708a19b4e 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -510,7 +510,8 @@ method publish*(g: GossipSub, msg = Message.init(g.peerInfo, data, topic, g.sign) msgId = g.msgIdProvider(msg) - trace "publishing on topic", name = topic, peers = peers, msg = msg.shortLog() + trace "publishing on topic", + topic, peers = peers.len, msg = msg.shortLog() if msgId notin g.mcache: g.mcache.put(msgId, msg) From b8b0a2b4bce8448b3aabc67fda96534124073c1c Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?=C8=98tefan=20Talpalaru?= Date: Tue, 14 Jul 2020 02:02:16 +0200 Subject: [PATCH 23/23] CI: build binaries with TRACE & JSON logs (#268) Also: remove unused imports. --- libp2p.nimble | 9 ++++++--- libp2p/multistream.nim | 1 - libp2p/muxers/mplex/lpchannel.nim | 2 -- libp2p/muxers/mplex/mplex.nim | 1 - libp2p/muxers/muxer.nim | 1 - libp2p/protocols/pubsub/floodsub.nim | 4 +--- libp2p/stream/chronosstream.nim | 3 +-- libp2p/switch.nim | 4 ---- libp2p/transports/tcptransport.nim | 5 +---- libp2p/transports/transport.nim | 5 ++--- tests/pubsub/testfloodsub.nim | 5 ++--- tests/testmultistream.nim | 2 +- tests/testswitch.nim | 6 ------ 13 files changed, 14 insertions(+), 34 deletions(-) diff --git a/libp2p.nimble b/libp2p.nimble index 267599028..106cd59d5 100644 --- a/libp2p.nimble +++ b/libp2p.nimble @@ -17,12 +17,15 @@ requires "nim >= 1.2.0", "stew >= 0.1.0" proc runTest(filename: string, verify: bool = true, sign: bool = true) = - var excstr = "nim c -r --opt:speed -d:debug --verbosity:0 --hints:off -d:chronicles_log_level=info" + var excstr = "nim c --opt:speed -d:debug --verbosity:0 --hints:off" excstr.add(" --warning[CaseTransition]:off --warning[ObservableStores]:off --warning[LockLevel]:off") excstr.add(" -d:libp2p_pubsub_sign=" & $sign) excstr.add(" -d:libp2p_pubsub_verify=" & $verify) - excstr.add(" tests/" & filename) - exec excstr + if verify and sign: + # build it with TRACE and JSON logs + exec excstr & " -d:chronicles_log_level=TRACE -d:chronicles_sinks:json" & " tests/" & filename + # build it again, to run it with less verbose logs + exec excstr & " -d:chronicles_log_level=INFO -r" & " tests/" & filename rmFile "tests/" & filename.toExe proc buildSample(filename: string) = diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 4d4006838..57499da9d 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -11,7 +11,6 @@ import strutils import chronos, chronicles, stew/byteutils import stream/connection, vbuffer, - errors, protocols/protocol logScope: diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 41be7cbad..ff63946d3 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -14,8 +14,6 @@ import types, nimcrypto/utils, ../../stream/connection, ../../stream/bufferstream, - ../../utility, - ../../errors, ../../peerinfo export connection diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 69666419c..c11172da3 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -13,7 +13,6 @@ import ../muxer, ../../stream/connection, ../../stream/bufferstream, ../../utility, - ../../errors, ../../peerinfo, coder, types, diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index 2d6116037..001fbc761 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -10,7 +10,6 @@ import chronos, chronicles import ../protocols/protocol, ../stream/connection, - ../peerinfo, ../errors logScope: diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index cf43b70f4..a0fa44c64 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -15,9 +15,7 @@ import pubsub, rpc/[messages, message], ../../stream/connection, ../../peerid, - ../../peerinfo, - ../../utility, - ../../errors + ../../peerinfo logScope: topics = "floodsub" diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 4c27ff2e5..1d200d5c3 100644 
--- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -7,9 +7,8 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import oids import chronos, chronicles -import connection, ../utility +import connection logScope: topics = "chronosstream" diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 015eb32a7..d9e047348 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -10,7 +10,6 @@ import tables, sequtils, options, - strformat, sets, algorithm, oids @@ -20,18 +19,15 @@ import chronos, metrics import stream/connection, - stream/chronosstream, transports/transport, multistream, multiaddress, protocols/protocol, protocols/secure/secure, - protocols/secure/plaintext, # for plain text peerinfo, protocols/identify, protocols/pubsub/pubsub, muxers/muxer, - errors, peerid logScope: diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 6edb510ea..032b3e57b 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -16,9 +16,6 @@ import transport, ../stream/connection, ../stream/chronosstream -when chronicles.enabledLogLevel == LogLevel.TRACE: - import oids - logScope: topics = "tcptransport" @@ -74,7 +71,7 @@ proc connHandler*(t: TcpTransport, proc cleanup() {.async.} = try: await client.join() - trace "cleaning up client", addrs = client.remoteAddress, connoid = conn.oid + trace "cleaning up client", addrs = $client.remoteAddress, connoid = conn.oid if not(isNil(conn)): await conn.close() t.clients.keepItIf(it != client) diff --git a/libp2p/transports/transport.nim b/libp2p/transports/transport.nim index ff98c8ee8..a570d79a4 100644 --- a/libp2p/transports/transport.nim +++ b/libp2p/transports/transport.nim @@ -7,12 +7,11 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. 
-import sequtils, tables +import sequtils import chronos, chronicles import ../stream/connection, ../multiaddress, - ../multicodec, - ../errors + ../multicodec type ConnHandler* = proc (conn: Connection): Future[void] {.gcsafe.} diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 199921053..2d21b3118 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -9,7 +9,7 @@ {.used.} -import unittest, sequtils, options, tables, sets +import unittest, sequtils, options, tables import chronos, stew/byteutils import utils, ../../libp2p/[errors, @@ -18,8 +18,7 @@ import utils, crypto/crypto, protocols/pubsub/pubsub, protocols/pubsub/floodsub, - protocols/pubsub/rpc/messages, - protocols/pubsub/rpc/message] + protocols/pubsub/rpc/messages] import ../helpers diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index b36440be4..7f944b6cf 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -1,4 +1,4 @@ -import unittest, strutils, sequtils, strformat, stew/byteutils +import unittest, strutils, strformat, stew/byteutils import chronos import ../libp2p/errors, ../libp2p/multistream, diff --git a/tests/testswitch.nim b/tests/testswitch.nim index e06212504..a486b4034 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -9,19 +9,13 @@ import ../libp2p/[errors, multistream, standard_setup, stream/bufferstream, - protocols/identify, stream/connection, - transports/transport, - transports/tcptransport, multiaddress, peerinfo, crypto/crypto, protocols/protocol, muxers/muxer, muxers/mplex/mplex, - muxers/mplex/types, - protocols/secure/secio, - protocols/secure/secure, stream/lpstream] import ./helpers
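
The central refactor running through the test diffs above is that the mesh/fanout/gossipsub tables now hold HashSet[PubSubPeer] instead of HashSet[string], which is what the new hash/`==` overloads for PubSubPeer make possible. As a rough standalone illustration of that idiom only (this is not code from the patches: PeerLike and its id field are made-up stand-ins for PubSubPeer and peerInfo.id):

    import hashes, sets

    type
      PeerLike = ref object
        id: string                         # stands in for PubSubPeer.peerInfo.id

    func hash(p: PeerLike): Hash =
      # hash the ref itself (pointer identity), as the PubSubPeer overload does
      cast[pointer](p).hash

    func `==`(a, b: PeerLike): bool =
      # nil-safe comparison that otherwise falls back to the id
      if a.isNil or b.isNil: a.isNil and b.isNil
      else: a.id == b.id

    when isMainModule:
      var mesh = initHashSet[PeerLike]()
      let p1 = PeerLike(id: "peerA")
      let p2 = PeerLike(id: "peerB")
      mesh.incl(p1)
      mesh.incl(p2)
      doAssert p1 in mesh                  # same ref is found via its identity hash
      doAssert mesh.len == 2
      doAssert PeerLike(id: "peerA") == p1 # `==` compares by id
      var nilPeer: PeerLike
      doAssert not (nilPeer == p1)         # and tolerates nil on either side
      mesh.excl(p1)
      doAssert p1 notin mesh

Note that, as in the patch itself, hashing is by ref identity while `==` falls back to the peer id: lookups with the same ref are reliable, whereas two distinct refs sharing an id compare equal but may hash into different buckets.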