diff --git a/packages/sdk/src/reliable_channel/reliable_channel.ts b/packages/sdk/src/reliable_channel/reliable_channel.ts index 49b55aa495..3ed2d6d549 100644 --- a/packages/sdk/src/reliable_channel/reliable_channel.ts +++ b/packages/sdk/src/reliable_channel/reliable_channel.ts @@ -17,6 +17,7 @@ import { isContentMessage, MessageChannel, MessageChannelEvent, + MessageChannelEvents, type MessageChannelOptions, Message as SdsMessage, type SenderId, @@ -136,11 +137,16 @@ export class ReliableChannel< callback: Callback ) => Promise; + private readonly _unsubscribe?: ( + decoders: IDecoder | IDecoder[] + ) => Promise; + private readonly _retrieve?: ( decoders: IDecoder[], options?: Partial ) => AsyncGenerator[]>; + private eventListenerCleanups: Array<() => void> = []; private readonly syncMinIntervalMs: number; private syncTimeout: ReturnType | undefined; private sweepInBufInterval: ReturnType | undefined; @@ -151,6 +157,7 @@ export class ReliableChannel< private readonly queryOnConnect?: QueryOnConnect; private readonly processTaskMinElapseMs: number; private _started: boolean; + private activePendingProcessTask?: Promise; private constructor( public node: IWaku, @@ -170,6 +177,7 @@ export class ReliableChannel< if (node.filter) { this._subscribe = node.filter.subscribe.bind(node.filter); + this._unsubscribe = node.filter.unsubscribe.bind(node.filter); } else if (node.relay) { // TODO: Why do relay and filter have different interfaces? // this._subscribe = node.relay.subscribeWithUnsubscribe; @@ -384,10 +392,21 @@ export class ReliableChannel< private async subscribe(): Promise { this.assertStarted(); return this._subscribe(this.decoder, async (message: T) => { + if (!this._started) { + log.info("ReliableChannel stopped, ignoring incoming message"); + return; + } await this.processIncomingMessage(message); }); } + private async unsubscribe(): Promise { + if (!this._unsubscribe) { + throw Error("No unsubscribe method available"); + } + return await this._unsubscribe(this.decoder); + } + /** * Don't forget to call `this.messageChannel.sweepIncomingBuffer();` once done. * @param msg @@ -458,12 +477,19 @@ export class ReliableChannel< // TODO: For now we only queue process tasks for incoming messages // As this is where there is most volume private queueProcessTasks(): void { + if (!this._started) return; + // If one is already queued, then we can ignore it if (this.processTaskTimeout === undefined) { this.processTaskTimeout = setTimeout(() => { - void this.messageChannel.processTasks().catch((err) => { - log.error("error encountered when processing sds tasks", err); - }); + this.activePendingProcessTask = this.messageChannel + .processTasks() + .catch((err) => { + log.error("error encountered when processing sds tasks", err); + }) + .finally(() => { + this.activePendingProcessTask = undefined; + }); // Clear timeout once triggered clearTimeout(this.processTaskTimeout); @@ -485,15 +511,41 @@ export class ReliableChannel< return this.subscribe(); } - public stop(): void { + public async stop(): Promise { if (!this._started) return; + + log.info("Stopping ReliableChannel..."); this._started = false; + this.stopSync(); this.stopSweepIncomingBufferLoop(); - this.missingMessageRetriever?.stop(); - this.queryOnConnect?.stop(); - // TODO unsubscribe - // TODO unsetMessageListeners + + if (this.processTaskTimeout) { + clearTimeout(this.processTaskTimeout); + this.processTaskTimeout = undefined; + } + + if (this.activePendingProcessTask) { + await this.activePendingProcessTask; + } + + if (this.missingMessageRetriever) { + await this.missingMessageRetriever.stop(); + } + + if (this.queryOnConnect) { + this.queryOnConnect.stop(); + } + + if (this.retryManager) { + this.retryManager.stopAllRetries(); + } + + await this.unsubscribe(); + + this.removeAllEventListeners(); + + log.info("ReliableChannel stopped successfully"); } private assertStarted(): void { @@ -509,12 +561,16 @@ export class ReliableChannel< } private stopSweepIncomingBufferLoop(): void { - if (this.sweepInBufInterval) clearInterval(this.sweepInBufInterval); + if (this.sweepInBufInterval) { + clearInterval(this.sweepInBufInterval); + this.sweepInBufInterval = undefined; + } } private restartSync(multiplier: number = 1): void { if (this.syncTimeout) { clearTimeout(this.syncTimeout); + this.syncTimeout = undefined; } if (this.syncMinIntervalMs) { const timeoutMs = this.random() * this.syncMinIntervalMs * multiplier; @@ -531,6 +587,7 @@ export class ReliableChannel< private stopSync(): void { if (this.syncTimeout) { clearTimeout(this.syncTimeout); + this.syncTimeout = undefined; } } @@ -595,8 +652,19 @@ export class ReliableChannel< return sdsMessage.causalHistory && sdsMessage.causalHistory.length > 0; } + private addTrackedEventListener( + eventName: K, + listener: (event: MessageChannelEvents[K]) => void + ): void { + this.messageChannel.addEventListener(eventName, listener as any); + + this.eventListenerCleanups.push(() => { + this.messageChannel.removeEventListener(eventName, listener as any); + }); + } + private setupEventListeners(): void { - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.OutMessageSent, (event) => { if (event.detail.content) { @@ -608,7 +676,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.OutMessageAcknowledged, (event) => { if (event.detail) { @@ -622,7 +690,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.OutMessagePossiblyAcknowledged, (event) => { if (event.detail) { @@ -636,7 +704,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.InSyncReceived, (_event) => { // restart the timeout when a sync message has been received @@ -644,7 +712,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.InMessageReceived, (event) => { // restart the timeout when a content message has been received @@ -655,7 +723,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.OutMessageSent, (event) => { // restart the timeout when a content message has been sent @@ -665,7 +733,7 @@ export class ReliableChannel< } ); - this.messageChannel.addEventListener( + this.addTrackedEventListener( MessageChannelEvent.InMessageMissing, (event) => { for (const { messageId, retrievalHint } of event.detail) { @@ -680,12 +748,32 @@ export class ReliableChannel< ); if (this.queryOnConnect) { + const queryListener = (event: any): void => { + void this.processIncomingMessages(event.detail); + }; + this.queryOnConnect.addEventListener( QueryOnConnectEvent.MessagesRetrieved, - (event) => { - void this.processIncomingMessages(event.detail); - } + queryListener ); + + this.eventListenerCleanups.push(() => { + this.queryOnConnect?.removeEventListener( + QueryOnConnectEvent.MessagesRetrieved, + queryListener + ); + }); } } + + private removeAllEventListeners(): void { + for (const cleanup of this.eventListenerCleanups) { + try { + cleanup(); + } catch (error) { + log.error("error removing event listener:", error); + } + } + this.eventListenerCleanups = []; + } }