diff --git a/libp2p/protocols/pubsub/errors.nim b/libp2p/protocols/pubsub/errors.nim new file mode 100644 index 000000000..cfb2ccc6d --- /dev/null +++ b/libp2p/protocols/pubsub/errors.nim @@ -0,0 +1,6 @@ +# this module will be further extended in PR +# https://github.com/status-im/nim-libp2p/pull/107/ + +type + ValidationResult* {.pure.} = enum + Accept, Reject, Ignore diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index eda35ee2f..54d6f281b 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -96,7 +96,14 @@ method rpcHandler*(f: FloodSub, f.handleSubscribe(peer, sub.topic, sub.subscribe) for msg in rpcMsg.messages: # for every message - let msgId = f.msgIdProvider(msg) + let msgIdResult = f.msgIdProvider(msg) + if msgIdResult.isErr: + debug "Dropping message due to failed message id generation", + error = msgIdResult.error + # TODO: descore peers due to error during message validation (malicious?) + continue + + let msgId = msgIdResult.get if f.addSeen(msgId): trace "Dropping already-seen message", msgId, peer @@ -184,7 +191,14 @@ method publish*(f: FloodSub, Message.init(none(PeerInfo), data, topic, none(uint64), false) else: Message.init(some(f.peerInfo), data, topic, some(f.msgSeqno), f.sign) - msgId = f.msgIdProvider(msg) + msgIdResult = f.msgIdProvider(msg) + + if msgIdResult.isErr: + trace "Error generating message id, skipping publish", + error = msgIdResult.error + return 0 + + let msgId = msgIdResult.get trace "Created new message", msg = shortLog(msg), peers = peers.len, topic, msgId diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index d8753d4fb..9a2a574fa 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -362,8 +362,16 @@ method rpcHandler*(g: GossipSub, for i in 0.. 0 and m.fromPeer.data.len > 0: - byteutils.toHex(m.seqno) & $m.fromPeer - else: - # This part is irrelevant because it's not standard, - # We use it exclusively for testing basically and users should - # implement their own logic in the case they use anonymization - $m.data.hash & $m.topicIDs.hash - mid.toBytes() +func defaultMsgIdProvider*(m: Message): Result[MessageID, ValidationResult] = + if m.seqno.len > 0 and m.fromPeer.data.len > 0: + let mid = byteutils.toHex(m.seqno) & $m.fromPeer + ok mid.toBytes() + else: + err ValidationResult.Reject proc sign*(msg: Message, privateKey: PrivateKey): CryptoResult[seq[byte]] = ok((? privateKey.sign(PubSubPrefix & encodeMessage(msg, false))).getBytes()) diff --git a/tests/pubsub/testfloodsub.nim b/tests/pubsub/testfloodsub.nim index 0ef50800b..38d00d9cd 100644 --- a/tests/pubsub/testfloodsub.nim +++ b/tests/pubsub/testfloodsub.nim @@ -20,6 +20,7 @@ import utils, protocols/pubsub/floodsub, protocols/pubsub/rpc/messages, protocols/pubsub/peertable] +import ../../libp2p/protocols/pubsub/errors as pubsub_errors import ../helpers diff --git a/tests/pubsub/testgossipinternal.nim b/tests/pubsub/testgossipinternal.nim index d1ff8e9d4..ce4f7930a 100644 --- a/tests/pubsub/testgossipinternal.nim +++ b/tests/pubsub/testgossipinternal.nim @@ -39,6 +39,8 @@ proc randomPeerId(): PeerId = except CatchableError as exc: raise newException(Defect, exc.msg) +const MsgIdFail = "msg id gen failure" + suite "GossipSub internal": teardown: checkTrackers() @@ -308,7 +310,7 @@ suite "GossipSub internal": conn.peerId = peerId inc seqno let msg = Message.init(peerId, ("HELLO" & $i).toBytes(), topic, some(seqno)) - gossipSub.mcache.put(gossipSub.msgIdProvider(msg), msg) + gossipSub.mcache.put(gossipSub.msgIdProvider(msg).expect(MsgIdFail), msg) check gossipSub.fanout[topic].len == 15 check gossipSub.mesh[topic].len == 15 @@ -355,7 +357,7 @@ suite "GossipSub internal": conn.peerId = peerId inc seqno let msg = Message.init(peerId, ("HELLO" & $i).toBytes(), topic, some(seqno)) - gossipSub.mcache.put(gossipSub.msgIdProvider(msg), msg) + gossipSub.mcache.put(gossipSub.msgIdProvider(msg).expect(MsgIdFail), msg) let peers = gossipSub.getGossipPeers() check peers.len == gossipSub.parameters.d @@ -396,7 +398,7 @@ suite "GossipSub internal": conn.peerId = peerId inc seqno let msg = Message.init(peerId, ("HELLO" & $i).toBytes(), topic, some(seqno)) - gossipSub.mcache.put(gossipSub.msgIdProvider(msg), msg) + gossipSub.mcache.put(gossipSub.msgIdProvider(msg).expect(MsgIdFail), msg) let peers = gossipSub.getGossipPeers() check peers.len == gossipSub.parameters.d @@ -437,7 +439,7 @@ suite "GossipSub internal": conn.peerId = peerId inc seqno let msg = Message.init(peerId, ("bar" & $i).toBytes(), topic, some(seqno)) - gossipSub.mcache.put(gossipSub.msgIdProvider(msg), msg) + gossipSub.mcache.put(gossipSub.msgIdProvider(msg).expect(MsgIdFail), msg) let peers = gossipSub.getGossipPeers() check peers.len == 0 diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index a1332ffa9..11eeda65c 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -24,6 +24,7 @@ import utils, ../../libp2p/[errors, protocols/pubsub/peertable, protocols/pubsub/timedcache, protocols/pubsub/rpc/messages] +import ../../libp2p/protocols/pubsub/errors as pubsub_errors import ../helpers proc `$`(peer: PubSubPeer): string = shortLog(peer) diff --git a/tests/pubsub/testmcache.nim b/tests/pubsub/testmcache.nim index 4b7b7ce57..7ceab8165 100644 --- a/tests/pubsub/testmcache.nim +++ b/tests/pubsub/testmcache.nim @@ -5,19 +5,21 @@ import stew/byteutils import ../../libp2p/[peerid, crypto/crypto, protocols/pubsub/mcache, - protocols/pubsub/rpc/message, protocols/pubsub/rpc/messages] +import ./utils var rng = newRng() proc randomPeerId(): PeerId = PeerId.init(PrivateKey.random(ECDSA, rng[]).get()).get() +const MsgIdGenFail = "msg id gen failure" + suite "MCache": test "put/get": var mCache = MCache.init(3, 5) var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes()) - let msgId = defaultMsgIdProvider(msg) + let msgId = defaultMsgIdProvider(msg).expect(MsgIdGenFail) mCache.put(msgId, msg) check mCache.get(msgId).isSome and mCache.get(msgId).get() == msg @@ -28,13 +30,13 @@ suite "MCache": var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["foo"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) for i in 0..<5: var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["bar"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) var mids = mCache.window("foo") check mids.len == 3 @@ -49,7 +51,7 @@ suite "MCache": var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["foo"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) mCache.shift() check mCache.window("foo").len == 0 @@ -58,7 +60,7 @@ suite "MCache": var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["bar"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) mCache.shift() check mCache.window("bar").len == 0 @@ -67,7 +69,7 @@ suite "MCache": var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["baz"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) mCache.shift() check mCache.window("baz").len == 0 @@ -79,19 +81,19 @@ suite "MCache": var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["foo"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) for i in 0..<3: var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["bar"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) for i in 0..<3: var msg = Message(fromPeer: randomPeerId(), seqno: "12345".toBytes(), topicIDs: @["baz"]) - mCache.put(defaultMsgIdProvider(msg), msg) + mCache.put(defaultMsgIdProvider(msg).expect(MsgIdGenFail), msg) mCache.shift() check mCache.window("foo").len == 0 diff --git a/tests/pubsub/testmessage.nim b/tests/pubsub/testmessage.nim index d555d89f5..19e992c9c 100644 --- a/tests/pubsub/testmessage.nim +++ b/tests/pubsub/testmessage.nim @@ -3,8 +3,10 @@ import unittest2 {.used.} import options +import stew/byteutils import ../../libp2p/[peerid, peerinfo, crypto/crypto, + protocols/pubsub/errors, protocols/pubsub/rpc/message, protocols/pubsub/rpc/messages] @@ -18,3 +20,56 @@ suite "Message": msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true) check verify(msg) + + test "defaultMsgIdProvider success": + let + seqno = 11'u64 + pkHex = + """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C + E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F + E731065A""" + seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + .expect("invalid private key bytes") + peer = PeerInfo.new(seckey) + msg = Message.init(some(peer), @[], "topic", some(seqno), sign = true) + msgIdResult = msg.defaultMsgIdProvider() + + check: + msgIdResult.isOk + string.fromBytes(msgIdResult.get) == + "000000000000000b12D3KooWGyLzSt9g4U9TdHYDvVWAs5Ht4WrocgoyqPxxvnqAL8qw" + + test "defaultMsgIdProvider error - no source peer id": + let + seqno = 11'u64 + pkHex = + """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C + E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F + E731065A""" + seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + .expect("invalid private key bytes") + peer = PeerInfo.new(seckey) + + var msg = Message.init(peer.some, @[], "topic", some(seqno), sign = true) + msg.fromPeer = PeerId() + let msgIdResult = msg.defaultMsgIdProvider() + + check: + msgIdResult.isErr + msgIdResult.error == ValidationResult.Reject + + test "defaultMsgIdProvider error - no source seqno": + let + pkHex = + """08011240B9EA7F0357B5C1247E4FCB5AD09C46818ECB07318CA84711875F4C6C + E6B946186A4EB44E0D714B2A2D48263D75CF52D30BEF9D9AE2A9FEB7DAF1775F + E731065A""" + seckey = PrivateKey.init(fromHex(stripSpaces(pkHex))) + .expect("invalid private key bytes") + peer = PeerInfo.new(seckey) + msg = Message.init(some(peer), @[], "topic", uint64.none, sign = true) + msgIdResult = msg.defaultMsgIdProvider() + + check: + msgIdResult.isErr + msgIdResult.error == ValidationResult.Reject diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 50644deeb..f203f7421 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -4,24 +4,37 @@ const libp2p_pubsub_verify {.booldefine.} = true libp2p_pubsub_anonymize {.booldefine.} = false -import random, tables -import chronos +import hashes, random, tables +import chronos, stew/[byteutils, results] import ../../libp2p/[builders, + protocols/pubsub/errors, protocols/pubsub/pubsub, protocols/pubsub/gossipsub, protocols/pubsub/floodsub, + protocols/pubsub/rpc/messages, protocols/secure/secure] export builders randomize() +func defaultMsgIdProvider*(m: Message): Result[MessageID, ValidationResult] = + let mid = + if m.seqno.len > 0 and m.fromPeer.data.len > 0: + byteutils.toHex(m.seqno) & $m.fromPeer + else: + # This part is irrelevant because it's not standard, + # We use it exclusively for testing basically and users should + # implement their own logic in the case they use anonymization + $m.data.hash & $m.topicIDs.hash + ok mid.toBytes() + proc generateNodes*( num: Natural, secureManagers: openArray[SecureProtocol] = [ SecureProtocol.Noise ], - msgIdProvider: MsgIdProvider = nil, + msgIdProvider: MsgIdProvider = defaultMsgIdProvider, gossip: bool = false, triggerSelf: bool = false, verifySignature: bool = libp2p_pubsub_verify,