diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim index 30e616384..03a105ea0 100644 --- a/libp2p/protocols/pubsub/gossipsub.nim +++ b/libp2p/protocols/pubsub/gossipsub.nim @@ -453,6 +453,9 @@ proc validateAndRelay( g.rewardDelivered(peer, topic, true) + # trigger hooks + peer.validatedObservers(msg, msgId) + # The send list typically matches the idontwant list from above, but # might differ if validation takes time var toSendPeers = HashSet[PubSubPeer]() diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 9dd00f66a..79c5c90de 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -67,6 +67,8 @@ type PubSubObserver* = ref object onRecv*: proc(peer: PubSubPeer, msgs: var RPCMsg) {.gcsafe, raises: [].} onSend*: proc(peer: PubSubPeer, msgs: var RPCMsg) {.gcsafe, raises: [].} + onValidated*: + proc(peer: PubSubPeer, msg: Message, msgId: MessageId) {.gcsafe, raises: [].} PubSubPeerEventKind* {.pure.} = enum StreamOpened @@ -170,14 +172,23 @@ proc recvObservers*(p: PubSubPeer, msg: var RPCMsg) = if not (isNil(p.observers)) and p.observers[].len > 0: for obs in p.observers[]: if not (isNil(obs)): # TODO: should never be nil, but... - obs.onRecv(p, msg) + if not (isNil(obs.onRecv)): + obs.onRecv(p, msg) + +proc validatedObservers*(p: PubSubPeer, msg: Message, msgId: MessageId) = + # trigger hooks + if not (isNil(p.observers)) and p.observers[].len > 0: + for obs in p.observers[]: + if not (isNil(obs.onValidated)): + obs.onValidated(p, msg, msgId) proc sendObservers(p: PubSubPeer, msg: var RPCMsg) = # trigger hooks if not (isNil(p.observers)) and p.observers[].len > 0: for obs in p.observers[]: if not (isNil(obs)): # TODO: should never be nil, but... - obs.onSend(p, msg) + if not (isNil(obs.onSend)): + obs.onSend(p, msg) proc handle*(p: PubSubPeer, conn: Connection) {.async.} = debug "starting pubsub read loop", conn, peer = p, closed = conn.closed diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index caf41482c..9f059b8f8 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -220,6 +220,63 @@ suite "GossipSub": await allFuturesThrowing(nodesFut.concat()) + asyncTest "GossipSub's observers should run after message is sent, received and validated": + var + recvCounter = 0 + sendCounter = 0 + validatedCounter = 0 + + proc handler(topic: string, data: seq[byte]) {.async.} = + discard + + proc onRecv(peer: PubSubPeer, msgs: var RPCMsg) = + inc recvCounter + + proc onSend(peer: PubSubPeer, msgs: var RPCMsg) = + inc sendCounter + + proc onValidated(peer: PubSubPeer, msg: Message, msgId: MessageId) = + inc validatedCounter + + let obs0 = PubSubObserver(onSend: onSend) + let obs1 = PubSubObserver(onRecv: onRecv, onValidated: onValidated) + + let nodes = generateNodes(2, gossip = true) + # start switches + discard await allFinished(nodes[0].switch.start(), nodes[1].switch.start()) + + await subscribeNodes(nodes) + + nodes[0].addObserver(obs0) + nodes[1].addObserver(obs1) + nodes[1].subscribe("foo", handler) + nodes[1].subscribe("bar", handler) + + proc validator( + topic: string, message: Message + ): Future[ValidationResult] {.async.} = + result = if topic == "foo": ValidationResult.Accept else: ValidationResult.Reject + + nodes[1].addValidator("foo", "bar", validator) + + # Send message that will be accepted by the receiver's validator + tryPublish await nodes[0].publish("foo", "Hello!".toBytes()), 1 + + check: + recvCounter == 1 + validatedCounter == 1 + sendCounter == 1 + + # Send message that will be rejected by the receiver's validator + tryPublish await nodes[0].publish("bar", "Hello!".toBytes()), 1 + + check: + recvCounter == 2 + validatedCounter == 1 + sendCounter == 2 + + await allFuturesThrowing(nodes[0].switch.stop(), nodes[1].switch.stop()) + asyncTest "GossipSub unsub - resub faster than backoff": var handlerFut = newFuture[bool]() proc handler(topic: string, data: seq[byte]) {.async.} =