diff --git a/packages/sdk/src/protocols/filter.ts b/packages/sdk/src/protocols/filter.ts index 18f1dc7521..eddad0a49f 100644 --- a/packages/sdk/src/protocols/filter.ts +++ b/packages/sdk/src/protocols/filter.ts @@ -24,7 +24,6 @@ import { SubscribeResult, type Unsubscribe } from "@waku/interfaces"; -import { messageHashStr } from "@waku/message-hash"; import { WakuMessage } from "@waku/proto"; import { ensurePubsubTopicIsConfigured, @@ -35,36 +34,27 @@ import { } from "@waku/utils"; import { BaseProtocolSDK } from "./base_protocol.js"; +import { FilterReliabilityMonitor as ReliabilityMonitor } from "./reliability_monitor.js"; type SubscriptionCallback = { decoders: IDecoder[]; callback: Callback; }; -type ReceivedMessageHashes = { - all: Set; - nodes: { - [peerId: PeerIdStr]: Set; - }; -}; - const log = new Logger("sdk:filter"); const DEFAULT_MAX_PINGS = 3; -const DEFAULT_MAX_MISSED_MESSAGES_THRESHOLD = 3; const DEFAULT_KEEP_ALIVE = 30 * 1000; const DEFAULT_SUBSCRIBE_OPTIONS = { keepAlive: DEFAULT_KEEP_ALIVE }; + export class SubscriptionManager implements ISubscriptionSDK { - private readonly receivedMessagesHashStr: string[] = []; private keepAliveTimer: number | null = null; - private readonly receivedMessagesHashes: ReceivedMessageHashes; private peerFailures: Map = new Map(); - private missedMessagesByPeer: Map = new Map(); private maxPingFailures: number = DEFAULT_MAX_PINGS; - private maxMissedMessagesThreshold = DEFAULT_MAX_MISSED_MESSAGES_THRESHOLD; + private reliabilityMonitor: ReliabilityMonitor; private subscriptionCallbacks: Map< ContentTopic, @@ -79,26 +69,9 @@ export class SubscriptionManager implements ISubscriptionSDK { ) { this.pubsubTopic = pubsubTopic; this.subscriptionCallbacks = new Map(); - const allPeerIdStr = this.getPeers().map((p) => p.id.toString()); - this.receivedMessagesHashes = { - all: new Set(), - nodes: { - ...Object.fromEntries(allPeerIdStr.map((peerId) => [peerId, new Set()])) - } - }; - allPeerIdStr.forEach((peerId) => this.missedMessagesByPeer.set(peerId, 0)); - } - - public get messageHashes(): string[] { - return [...this.receivedMessagesHashes.all]; - } - - private addHash(hash: string, peerIdStr?: string): void { - this.receivedMessagesHashes.all.add(hash); - - if (peerIdStr) { - this.receivedMessagesHashes.nodes[peerIdStr].add(hash); - } + this.reliabilityMonitor = new ReliabilityMonitor( + this.renewAndSubscribePeer + ); } public async subscribe( @@ -108,9 +81,6 @@ export class SubscriptionManager implements ISubscriptionSDK { ): Promise { this.keepAliveTimer = options.keepAlive || DEFAULT_KEEP_ALIVE; this.maxPingFailures = options.pingsBeforePeerRenewed || DEFAULT_MAX_PINGS; - this.maxMissedMessagesThreshold = - options.maxMissedMessagesThreshold || - DEFAULT_MAX_MISSED_MESSAGES_THRESHOLD; const decodersArray = Array.isArray(decoders) ? decoders : [decoders]; @@ -218,55 +188,17 @@ export class SubscriptionManager implements ISubscriptionSDK { return finalResult; } - private async validateMessage(): Promise { - for (const hash of this.receivedMessagesHashes.all) { - for (const [peerIdStr, hashes] of Object.entries( - this.receivedMessagesHashes.nodes - )) { - if (!hashes.has(hash)) { - this.incrementMissedMessageCount(peerIdStr); - if (this.shouldRenewPeer(peerIdStr)) { - log.info( - `Peer ${peerIdStr} has missed too many messages, renewing.` - ); - const peerId = this.getPeers().find( - (p) => p.id.toString() === peerIdStr - )?.id; - if (!peerId) { - log.error( - `Unexpected Error: Peer ${peerIdStr} not found in connected peers.` - ); - continue; - } - try { - await this.renewAndSubscribePeer(peerId); - } catch (error) { - log.error(`Failed to renew peer ${peerIdStr}: ${error}`); - } - } - } - } - } - } - public async processIncomingMessage( message: WakuMessage, peerIdStr: PeerIdStr ): Promise { - const hashedMessageStr = messageHashStr( + await this.reliabilityMonitor.processAndValidateMessage( + message, this.pubsubTopic, - message as IProtoMessage + peerIdStr, + this.getPeers ); - this.addHash(hashedMessageStr, peerIdStr); - void this.validateMessage(); - - if (this.receivedMessagesHashStr.includes(hashedMessageStr)) { - log.info("Message already received, skipping"); - return; - } - this.receivedMessagesHashStr.push(hashedMessageStr); - const { contentTopic } = message; const subscriptionCallback = this.subscriptionCallbacks.get(contentTopic); if (!subscriptionCallback) { @@ -345,10 +277,16 @@ export class SubscriptionManager implements ISubscriptionSDK { if (failures > this.maxPingFailures) { try { - await this.renewAndSubscribePeer(peerId); - this.peerFailures.delete(peerId.toString()); + const newPeer = await this.renewAndSubscribePeer(peerId); + if (newPeer) { + this.peerFailures.delete(peerId.toString()); + this.reliabilityMonitor.resetPeer(peerId.toString()); + this.reliabilityMonitor.resetPeer(newPeer.id.toString()); + } } catch (error) { - log.error(`Failed to renew peer ${peerId.toString()}: ${error}.`); + log.error( + `Failed to renew and subscribe peer ${peerId.toString()}: ${error}.` + ); } } } @@ -363,18 +301,12 @@ export class SubscriptionManager implements ISubscriptionSDK { newPeer, Array.from(this.subscriptionCallbacks.keys()) ); - - this.receivedMessagesHashes.nodes[newPeer.id.toString()] = new Set(); - this.missedMessagesByPeer.set(newPeer.id.toString(), 0); - return newPeer; } catch (error) { - log.warn(`Failed to renew peer ${peerId.toString()}: ${error}.`); - return; - } finally { - this.peerFailures.delete(peerId.toString()); - this.missedMessagesByPeer.delete(peerId.toString()); - delete this.receivedMessagesHashes.nodes[peerId.toString()]; + log.warn( + `Failed to renew and subscribe peer ${peerId.toString()}: ${error}.` + ); + return undefined; } } @@ -402,16 +334,6 @@ export class SubscriptionManager implements ISubscriptionSDK { clearInterval(this.keepAliveTimer); this.keepAliveTimer = null; } - - private incrementMissedMessageCount(peerIdStr: string): void { - const currentCount = this.missedMessagesByPeer.get(peerIdStr) || 0; - this.missedMessagesByPeer.set(peerIdStr, currentCount + 1); - } - - private shouldRenewPeer(peerIdStr: string): boolean { - const missedMessages = this.missedMessagesByPeer.get(peerIdStr) || 0; - return missedMessages > this.maxMissedMessagesThreshold; - } } class FilterSDK extends BaseProtocolSDK implements IFilterSDK { diff --git a/packages/sdk/src/protocols/reliability_monitor.ts b/packages/sdk/src/protocols/reliability_monitor.ts new file mode 100644 index 0000000000..67b911a476 --- /dev/null +++ b/packages/sdk/src/protocols/reliability_monitor.ts @@ -0,0 +1,78 @@ +import type { Peer, PeerId } from "@libp2p/interface"; +import { IProtoMessage, PeerIdStr } from "@waku/interfaces"; +import { messageHashStr } from "@waku/message-hash"; +import { type WakuMessage } from "@waku/proto"; +import { Logger } from "@waku/utils"; + +const DEFAULT_MAX_MISSED_MESSAGES_THRESHOLD = 3; + +const log = new Logger("waku:message-monitor"); + +export class ReliabilityMonitor { + private readonly receivedMessagesHashes: Set = new Set(); + private readonly messageHashesByPeer: Map> = new Map(); + private readonly missedMessagesByPeer: Map = new Map(); + private readonly maxMissedMessagesThreshold: number; + + public constructor( + private readonly renewAndSubscribePeer: ( + peerToDisconnect: PeerId + ) => Promise + ) { + this.maxMissedMessagesThreshold = DEFAULT_MAX_MISSED_MESSAGES_THRESHOLD; + } + + public async processAndValidateMessage( + message: WakuMessage, + pubsubTopic: string, + peerIdStr: PeerIdStr, + getPeers: () => Peer[] + ): Promise { + const hash = messageHashStr(pubsubTopic, message as IProtoMessage); + this.addMessage(hash, peerIdStr); + await this.validateMessages(getPeers); + } + + public resetPeer(peerIdStr: PeerIdStr): void { + this.messageHashesByPeer.delete(peerIdStr); + this.missedMessagesByPeer.delete(peerIdStr); + } + + private addMessage(hash: string, peerIdStr: PeerIdStr): void { + this.receivedMessagesHashes.add(hash); + if (!this.messageHashesByPeer.has(peerIdStr)) { + this.messageHashesByPeer.set(peerIdStr, new Set()); + } + this.messageHashesByPeer.get(peerIdStr)!.add(hash); + } + + private async validateMessages(getPeers: () => Peer[]): Promise { + const peersToRenew: PeerIdStr[] = []; + for (const [peerIdStr, hashes] of this.messageHashesByPeer.entries()) { + const missedMessages = [...this.receivedMessagesHashes].filter( + (hash) => !hashes.has(hash) + ).length; + this.missedMessagesByPeer.set(peerIdStr, missedMessages); + if (missedMessages > this.maxMissedMessagesThreshold) { + peersToRenew.push(peerIdStr); + } + } + + for (const peerIdStr of peersToRenew) { + const peerId = getPeers().find((p) => p.id.toString() === peerIdStr)?.id; + if (peerId) { + try { + const newPeer = await this.renewAndSubscribePeer(peerId); + if (newPeer) { + this.resetPeer(peerIdStr); + this.resetPeer(newPeer.id.toString()); + } + } catch (error) { + log.error( + `Failed to renew and subscribe peer ${peerIdStr}: ${error}` + ); + } + } + } + } +}