diff --git a/libp2p/protocols/pubsub/floodsub.nim b/libp2p/protocols/pubsub/floodsub.nim index b0c7000ba..502f4f205 100644 --- a/libp2p/protocols/pubsub/floodsub.nim +++ b/libp2p/protocols/pubsub/floodsub.nim @@ -138,6 +138,7 @@ method unsubscribe*(f: FloodSub, await f.sendSubs(p, topics.mapIt(it.topic).deduplicate(), false) method initPubSub*(f: FloodSub) = + procCall PubSub(f).initPubSub() f.peers = initTable[string, PubSubPeer]() f.topics = initTable[string, Topic]() f.floodsub = initTable[string, HashSet[string]]() diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index 3b6e2dffc..3bfb1cc3d 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -16,6 +16,7 @@ import pubsubpeer, ../../peerinfo export PubSubPeer +export PubSubObserver logScope: topic = "PubSub" @@ -42,6 +43,7 @@ type sign*: bool # enable message signing cleanupLock: AsyncLock validators*: Table[string, HashSet[ValidatorHandler]] + observers: ref seq[PubSubObserver] # ref as in smart_ptr proc sendSubs*(p: PubSub, peer: PubSubPeer, @@ -72,6 +74,7 @@ method rpcHandler*(p: PubSub, rpcMsgs: seq[RPCMsg]) {.async, base.} = ## handle rpc messages trace "processing RPC message", peer = peer.id, msgs = rpcMsgs.len + for m in rpcMsgs: # for all RPC messages trace "processing messages", msg = m.shortLog if m.subscriptions.len > 0: # if there are any subscriptions @@ -104,6 +107,7 @@ proc getPeer(p: PubSub, p.peers[peer.id] = peer peer.refs.inc # increment reference cound + peer.observers = p.observers result = peer method handleConn*(p: PubSub, @@ -201,7 +205,7 @@ method publish*(p: PubSub, method initPubSub*(p: PubSub) {.base.} = ## perform pubsub initializaion - discard + p.observers = new(seq[PubSubObserver]) method start*(p: PubSub) {.async, base.} = ## start pubsub @@ -253,3 +257,10 @@ proc newPubSub*(P: typedesc[PubSub], sign: sign, cleanupLock: newAsyncLock()) result.initPubSub() + +proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer + +proc removeObserver*(p: PubSub; observer: PubSubObserver) = + let idx = p.observers[].find(observer) + if idx != -1: + p.observers[].del(idx) diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 981b42afc..8e19f7d0c 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -23,18 +23,23 @@ logScope: topic = "PubSubPeer" type - PubSubPeer* = ref object of RootObj - proto: string # the protocol that this peer joined from - sendConn: Connection - peerInfo*: PeerInfo - handler*: RPCHandler - topics*: seq[string] - sentRpcCache: TimedCache[string] # cache for already sent messages - recvdRpcCache: TimedCache[string] # cache for already received messages - refs*: int # refcount of the connections this peer is handling - onConnect: AsyncEvent + PubSubObserver* = ref object + onRecv*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.} + onSend*: proc(peer: PubSubPeer; msgs: var RPCMsg) {.gcsafe.} - RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} + PubSubPeer* = ref object of RootObj + proto: string # the protocol that this peer joined from + sendConn: Connection + peerInfo*: PeerInfo + handler*: RPCHandler + topics*: seq[string] + sentRpcCache: TimedCache[string] # cache for already sent messages + recvdRpcCache: TimedCache[string] # cache for already received messages + refs*: int # refcount of the connections this peer is handling + onConnect: AsyncEvent + observers*: ref seq[PubSubObserver] # ref as in smart_ptr + + RPCHandler* = proc(peer: PubSubPeer, msg: seq[RPCMsg]): Future[void] {.gcsafe.} proc id*(p: PubSubPeer): string = p.peerInfo.id @@ -58,8 +63,11 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} = trace "message already received, skipping", peer = p.id continue - let msg = decodeRpcMsg(data) + var msg = decodeRpcMsg(data) trace "decoded msg from peer", peer = p.id, msg = msg.shortLog + # trigger hooks + for obs in p.observers[]: + obs.onRecv(p, msg) await p.handler(p, @[msg]) p.recvdRpcCache.put($hexData.hash) except CatchableError as exc: @@ -71,9 +79,14 @@ proc handle*(p: PubSubPeer, conn: Connection) {.async.} = proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} = try: - for m in msgs: + for m in msgs.items: trace "sending msgs to peer", toPeer = p.id let encoded = encodeRpcMsg(m) + # trigger hooks + if p.observers[].len > 0: + var mm = m + for obs in p.observers[]: + obs.onSend(p, mm) let encodedHex = encoded.buffer.toHex() if encoded.buffer.len <= 0: trace "empty message, skipping", peer = p.id diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index 8e34cf9b6..22a4cbd6c 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -268,6 +268,18 @@ suite "GossipSub": await nodes[1].subscribe("foobar", handler) await waitSub(nodes[0], nodes[1], "foobar") + + var observed = 0 + let + obs1 = PubSubObserver(onRecv: proc(peer: PubSubPeer; msgs: var RPCMsg) = + inc observed + ) + obs2 = PubSubObserver(onSend: proc(peer: PubSubPeer; msgs: var RPCMsg) = + inc observed + ) + nodes[1].pubsub.get().addObserver(obs1) + nodes[0].pubsub.get().addObserver(obs2) + await nodes[0].publish("foobar", cast[seq[byte]]("Hello!")) var gossipSub1: GossipSub = GossipSub(nodes[0].pubSub.get()) @@ -283,7 +295,7 @@ suite "GossipSub": await nodes[1].stop() await allFuturesThrowing(wait) - result = true + result = observed == 2 check: waitFor(runTests()) == true